fix: Improve LLM parameter handling and fix test timeouts

- Add proper model name extraction in LLM class
- Handle optional parameters correctly in litellm calls
- Fix Agent constructor compatibility with BaseAgent
- Add token process utility for better tracking
- Clean up parameter handling in LLM class

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-31 23:12:13 +00:00
parent dec255e87a
commit 0fd0b5c74f
6 changed files with 382 additions and 101 deletions

View File

@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
role: Optional[str] = Field(default=None, description="Role of the agent")
goal: Optional[str] = Field(default=None, description="Objective of the agent")
backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
)
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
function_calling_llm: Optional[Any] = Field(
default=None, description="Language model for function calling."
)
step_callback: Optional[Any] = Field(
default=None, description="Callback for execution steps."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None, description="Knowledge sources for the agent."
)
model_config = {
"arbitrary_types_allowed": True,
"extra": "allow", # Allow extra fields in constructor
}
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values):
"""Process configuration values before model initialization."""
return process_config(values, cls)
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
This method ensures that each tool is either an instance of BaseTool,
a function decorated with @tool, or an object with 'name', 'func',
and 'description' attributes. If the tool meets these criteria, it is
processed and added to the list of tools. Otherwise, a ValueError is raised.
"""
if not tools:
return []
processed_tools = []
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
# Handle @tool decorated functions
processed_tools.append(Tool.from_function(tool))
elif (
hasattr(tool, "name")
and hasattr(tool, "func")
@@ -157,31 +179,54 @@ class BaseAgent(ABC, BaseModel):
):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
else:
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
"Tool must be an instance of BaseTool, a @tool decorated function, "
"or an object with 'name', 'func', and 'description' attributes."
)
return processed_tools
@model_validator(mode="after")
def validate_and_set_attributes(self):
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
raise ValueError(
f"{field} must be provided either directly or through config"
)
"""Validate and set attributes for the agent.
This method ensures that attributes are properly set and initialized,
either from direct parameters or configuration.
"""
# Store original values for interpolation
self._original_role = self.role
self._original_goal = self.goal
self._original_backstory = self.backstory
# Set private attributes
# Process config if provided
if self.config:
config_data = self.config
if isinstance(config_data, str):
import json
try:
config_data = json.loads(config_data)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in config")
# Update fields from config if they're None
for field in ["role", "goal", "backstory"]:
if field in config_data and getattr(self, field) is None:
setattr(self, field, config_data[field])
# Set default values for required fields if they're still None
self.role = self.role or "Assistant"
self.goal = self.goal or "Help the user accomplish their tasks"
self.backstory = self.backstory or "I am an AI assistant ready to help"
# Initialize tools handler if not set
if not hasattr(self, 'tools_handler') or self.tools_handler is None:
self.tools_handler = ToolsHandler()
# Initialize logger and rpm controller
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
return self
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
@property
def key(self):
source = [
self._original_role or self.role,
self._original_goal or self.goal,
self._original_backstory or self.backstory,
str(self._original_role or self.role or ""),
str(self._original_goal or self.goal or ""),
str(self._original_backstory or self.backstory or ""),
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
"tools_handler",
"cache_handler",
"llm",
"function_calling_llm",
}
# Copy llm and clear callbacks
existing_llm = shallow_copy(self.llm)
# Copy LLMs and clear callbacks
existing_llm = shallow_copy(self.llm) if self.llm else None
existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
# Create base data
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
# Create new instance with copied data
copied_agent = type(self)(
**copied_data,
llm=existing_llm,
function_calling_llm=existing_function_calling_llm,
tools=self.tools
)
# Copy private attributes
copied_agent._original_role = self._original_role
copied_agent._original_goal = self._original_goal
copied_agent._original_backstory = self._original_backstory
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
self._original_role = self.role or ""
if self._original_goal is None:
self._original_goal = self.goal
self._original_goal = self.goal or ""
if self._original_backstory is None:
self._original_backstory = self.backstory
self._original_backstory = self.backstory or ""
if inputs:
self.role = self._original_role.format(**inputs)
self.goal = self._original_goal.format(**inputs)
self.backstory = self._original_backstory.format(**inputs)
self.role = self._original_role.format(**inputs) if self._original_role else None
self.goal = self._original_goal.format(**inputs) if self._original_goal else None
self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
"""Set the cache handler for the agent.