fix: Improve LLM parameter handling and fix test timeouts

- Add proper model name extraction in LLM class
- Handle optional parameters correctly in litellm calls
- Fix Agent constructor compatibility with BaseAgent
- Add token process utility for better tracking
- Clean up parameter handling in LLM class

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-31 23:12:13 +00:00
parent dec255e87a
commit 0fd0b5c74f
6 changed files with 382 additions and 101 deletions

View File

@@ -23,6 +23,9 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
from crewai.utilities.converter import generate_model_description
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
from crewai.utilities.logger import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.token_process import TokenProcess
agentops = None
@@ -45,24 +48,111 @@ class Agent(BaseAgent):
Each agent has a role, a goal, a backstory, and an optional language model (llm).
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
Attributes:
agent_executor: An instance of the CrewAgentExecutor class.
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
knowledge: The knowledge base of the agent.
config: Dict representation of agent configuration.
llm: The language model that will run the agent.
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
max_iter: Maximum number of iterations for an agent to execute a task.
memory: Whether the agent should have memory or not.
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
verbose: Whether the agent execution should be in verbose mode.
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
tools: Tools at agents disposal
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
Args:
role (Optional[str]): The role of the agent
goal (Optional[str]): The objective of the agent
backstory (Optional[str]): The backstory of the agent
allow_delegation (bool): Whether the agent can delegate tasks
config (Optional[Dict[str, Any]]): Configuration for the agent
verbose (bool): Whether to enable verbose output
max_rpm (Optional[int]): Maximum requests per minute
tools (Optional[List[Any]]): Tools available to the agent
llm (Optional[Union[str, Any]]): Language model to use
function_calling_llm (Optional[Any]): Language model for tool calling
max_iter (Optional[int]): Maximum iterations for task execution
memory (bool): Whether the agent should have memory
step_callback (Optional[Any]): Callback after each execution step
knowledge_sources (Optional[List[BaseKnowledgeSource]]): Knowledge sources
"""
model_config = {
"arbitrary_types_allowed": True,
"extra": "allow",
}
def __init__(
self,
role: Optional[str] = None,
goal: Optional[str] = None,
backstory: Optional[str] = None,
allow_delegation: bool = False,
config: Optional[Dict[str, Any]] = None,
verbose: bool = False,
max_rpm: Optional[int] = None,
tools: Optional[List[Any]] = None,
llm: Optional[Union[str, LLM, Any]] = None,
function_calling_llm: Optional[Any] = None,
max_iter: Optional[int] = None,
memory: bool = True,
step_callback: Optional[Any] = None,
knowledge_sources: Optional[List[BaseKnowledgeSource]] = None,
**kwargs
) -> None:
"""Initialize an Agent with the given parameters."""
# Process tools before passing to parent
processed_tools = []
if tools:
from crewai.tools import BaseTool
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif callable(tool):
# Convert function to BaseTool
processed_tools.append(Tool.from_function(tool))
else:
raise ValueError(f"Tool {tool} must be either a BaseTool instance or a callable")
# Process LLM before passing to parent
processed_llm = None
if isinstance(llm, str):
processed_llm = LLM(model=llm)
elif isinstance(llm, LLM):
processed_llm = llm
elif llm is not None and hasattr(llm, 'model') and hasattr(llm, 'temperature'):
# Handle ChatOpenAI and similar objects
model_name = getattr(llm, 'model', None)
if model_name is not None:
if not isinstance(model_name, str):
model_name = str(model_name)
processed_llm = LLM(
model=model_name,
temperature=getattr(llm, 'temperature', None),
api_key=getattr(llm, 'api_key', None),
base_url=getattr(llm, 'base_url', None)
)
# If no valid LLM configuration found, leave as None for post_init_setup
# Initialize all fields in a dict
init_dict = {
"role": role,
"goal": goal,
"backstory": backstory,
"allow_delegation": allow_delegation,
"config": config,
"verbose": verbose,
"max_rpm": max_rpm,
"tools": processed_tools,
"llm": processed_llm,
"max_iter": max_iter if max_iter is not None else 25,
"function_calling_llm": function_calling_llm,
"step_callback": step_callback,
"knowledge_sources": knowledge_sources,
**kwargs
}
# Initialize base model with all fields
super().__init__(**init_dict)
# Store original values for interpolation
self._original_role = role
self._original_goal = goal
self._original_backstory = backstory
# Initialize private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self._token_process = TokenProcess()
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
@@ -138,21 +228,15 @@ class Agent(BaseAgent):
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
self.agent_ops_agent_name = self.role or "agent"
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
# Handle different cases for self.llm
if isinstance(self.llm, str):
# If it's a string, create an LLM instance
self.llm = LLM(model=self.llm)
elif isinstance(self.llm, LLM):
# If it's already an LLM instance, keep it as is
pass
elif self.llm is None:
# Handle LLM initialization if not already done
if self.llm is None:
# Determine the model name from environment variables or use default
model_name = (
os.environ.get("OPENAI_MODEL_NAME")
@@ -301,7 +385,7 @@ class Agent(BaseAgent):
def _set_knowledge(self):
try:
if self.knowledge_sources:
knowledge_agent_name = f"{self.role.replace(' ', '_')}"
knowledge_agent_name = f"{(self.role or 'agent').replace(' ', '_')}"
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):

View File

@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
role: Optional[str] = Field(default=None, description="Role of the agent")
goal: Optional[str] = Field(default=None, description="Objective of the agent")
backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
)
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
function_calling_llm: Optional[Any] = Field(
default=None, description="Language model for function calling."
)
step_callback: Optional[Any] = Field(
default=None, description="Callback for execution steps."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None, description="Knowledge sources for the agent."
)
model_config = {
"arbitrary_types_allowed": True,
"extra": "allow", # Allow extra fields in constructor
}
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values):
"""Process configuration values before model initialization."""
return process_config(values, cls)
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
This method ensures that each tool is either an instance of BaseTool,
a function decorated with @tool, or an object with 'name', 'func',
and 'description' attributes. If the tool meets these criteria, it is
processed and added to the list of tools. Otherwise, a ValueError is raised.
"""
if not tools:
return []
processed_tools = []
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
# Handle @tool decorated functions
processed_tools.append(Tool.from_function(tool))
elif (
hasattr(tool, "name")
and hasattr(tool, "func")
@@ -157,31 +179,54 @@ class BaseAgent(ABC, BaseModel):
):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
else:
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
"Tool must be an instance of BaseTool, a @tool decorated function, "
"or an object with 'name', 'func', and 'description' attributes."
)
return processed_tools
@model_validator(mode="after")
def validate_and_set_attributes(self):
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
raise ValueError(
f"{field} must be provided either directly or through config"
)
"""Validate and set attributes for the agent.
This method ensures that attributes are properly set and initialized,
either from direct parameters or configuration.
"""
# Store original values for interpolation
self._original_role = self.role
self._original_goal = self.goal
self._original_backstory = self.backstory
# Set private attributes
# Process config if provided
if self.config:
config_data = self.config
if isinstance(config_data, str):
import json
try:
config_data = json.loads(config_data)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in config")
# Update fields from config if they're None
for field in ["role", "goal", "backstory"]:
if field in config_data and getattr(self, field) is None:
setattr(self, field, config_data[field])
# Set default values for required fields if they're still None
self.role = self.role or "Assistant"
self.goal = self.goal or "Help the user accomplish their tasks"
self.backstory = self.backstory or "I am an AI assistant ready to help"
# Initialize tools handler if not set
if not hasattr(self, 'tools_handler') or self.tools_handler is None:
self.tools_handler = ToolsHandler()
# Initialize logger and rpm controller
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
return self
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
@property
def key(self):
source = [
self._original_role or self.role,
self._original_goal or self.goal,
self._original_backstory or self.backstory,
str(self._original_role or self.role or ""),
str(self._original_goal or self.goal or ""),
str(self._original_backstory or self.backstory or ""),
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
"tools_handler",
"cache_handler",
"llm",
"function_calling_llm",
}
# Copy llm and clear callbacks
existing_llm = shallow_copy(self.llm)
# Copy LLMs and clear callbacks
existing_llm = shallow_copy(self.llm) if self.llm else None
existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
# Create base data
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
# Create new instance with copied data
copied_agent = type(self)(
**copied_data,
llm=existing_llm,
function_calling_llm=existing_function_calling_llm,
tools=self.tools
)
# Copy private attributes
copied_agent._original_role = self._original_role
copied_agent._original_goal = self._original_goal
copied_agent._original_backstory = self._original_backstory
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
self._original_role = self.role or ""
if self._original_goal is None:
self._original_goal = self.goal
self._original_goal = self.goal or ""
if self._original_backstory is None:
self._original_backstory = self.backstory
self._original_backstory = self.backstory or ""
if inputs:
self.role = self._original_role.format(**inputs)
self.goal = self._original_goal.format(**inputs)
self.backstory = self._original_backstory.format(**inputs)
self.role = self._original_role.format(**inputs) if self._original_role else None
self.goal = self._original_goal.format(**inputs) if self._original_goal else None
self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
"""Set the cache handler for the agent.

View File

@@ -68,7 +68,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.tools_handler = tools_handler
self.original_tools = original_tools
self.step_callback = step_callback
self.use_stop_words = self.llm.supports_stop_words()
self.use_stop_words = self.llm.supports_stop_words() if self.llm else False
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window

View File

@@ -96,7 +96,7 @@ def suppress_warnings():
class LLM:
def __init__(
self,
model: str,
model: Union[str, 'LLM'],
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -117,27 +117,51 @@ class LLM:
callbacks: List[Any] = [],
**kwargs,
):
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.kwargs = kwargs
# If model is an LLM instance, copy its configuration
if isinstance(model, LLM):
self.model = model.model
self.timeout = model.timeout
self.temperature = model.temperature
self.top_p = model.top_p
self.n = model.n
self.stop = model.stop
self.max_completion_tokens = model.max_completion_tokens
self.max_tokens = model.max_tokens
self.presence_penalty = model.presence_penalty
self.frequency_penalty = model.frequency_penalty
self.logit_bias = model.logit_bias
self.response_format = model.response_format
self.seed = model.seed
self.logprobs = model.logprobs
self.top_logprobs = model.top_logprobs
self.base_url = model.base_url
self.api_version = model.api_version
self.api_key = model.api_key
self.callbacks = model.callbacks
self.context_window_size = model.context_window_size
self.kwargs = model.kwargs
else:
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.kwargs = kwargs
litellm.drop_params = True
@@ -150,9 +174,38 @@ class LLM:
self.set_callbacks(callbacks)
try:
# Ensure model is a string and set default
model_name = "gpt-4" # Default model
# Extract model name from self.model
current = self.model
while current is not None:
if isinstance(current, str):
model_name = current
break
elif isinstance(current, LLM):
current = current.model
elif hasattr(current, "model"):
current = getattr(current, "model")
else:
break
# Set parameters for litellm
# Build base params dict with required fields
params = {
"model": self.model,
"model": model_name,
"custom_llm_provider": "openai",
"messages": messages,
"stream": False # Always set stream to False
}
# Add API configuration
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
params["api_key"] = api_key
# Define optional parameters
optional_params = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
@@ -166,12 +219,20 @@ class LLM:
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"api_base": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
**self.kwargs,
}
# Add API endpoint configuration if available
if self.base_url:
optional_params["api_base"] = self.base_url
if self.api_version:
optional_params["api_version"] = self.api_version
# Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None})
# Add any additional kwargs
if self.kwargs:
params.update(self.kwargs)
# Remove None values to avoid passing unnecessary parameters
params = {k: v for k, v in params.items() if v is not None}
@@ -195,6 +256,10 @@ class LLM:
return False
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns False if the LLM is not properly initialized."""
if not hasattr(self, 'model') or self.model is None:
return False
try:
params = get_supported_openai_params(model=self.model)
return "stop" in params

View File

@@ -63,16 +63,32 @@ class Prompts(BaseModel):
for component in components
if component != "task"
]
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
prompt = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
)
response = response_template.split("{{ .Response }}")[0]
prompt = f"{system}\n{prompt}\n{response}"
system = ""
if system_template:
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
prompt_text = ""
if prompt_template:
prompt_text = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
)
response = ""
if response_template:
response = response_template.split("{{ .Response }}")[0]
parts = [p for p in [system, prompt_text, response] if p]
prompt = "\n".join(parts) if parts else ""
# Get agent attributes with default values
goal = str(getattr(self.agent, 'goal', '') or '')
role = str(getattr(self.agent, 'role', '') or '')
backstory = str(getattr(self.agent, 'backstory', '') or '')
# Replace placeholders with agent attributes
prompt = (
prompt.replace("{goal}", self.agent.goal)
.replace("{role}", self.agent.role)
.replace("{backstory}", self.agent.backstory)
prompt.replace("{goal}", goal)
.replace("{role}", role)
.replace("{backstory}", backstory)
)
return prompt

View File

@@ -0,0 +1,55 @@
"""Token processing utility for tracking and managing token usage."""
from crewai.types.usage_metrics import UsageMetrics
class TokenProcess:
"""Handles token processing and tracking for agents."""
def __init__(self):
"""Initialize the token processor."""
self._token_count = 0
self._last_tokens = 0
def update_token_count(self, count: int) -> None:
"""Update the token count.
Args:
count (int): Number of tokens to add to the count
"""
self._token_count += count
self._last_tokens = count
def get_token_count(self) -> int:
"""Get the total token count.
Returns:
int: Total number of tokens processed
"""
return self._token_count
def get_last_tokens(self) -> int:
"""Get the number of tokens from the last update.
Returns:
int: Number of tokens from last update
"""
return self._last_tokens
def reset(self) -> None:
"""Reset the token counts to zero."""
self._token_count = 0
self._last_tokens = 0
def get_summary(self) -> UsageMetrics:
"""Get a summary of token usage.
Returns:
UsageMetrics: Object containing token usage metrics
"""
return UsageMetrics(
total_tokens=self._token_count,
prompt_tokens=0, # These will be set by the LLM handler
cached_prompt_tokens=0,
completion_tokens=self._last_tokens,
successful_requests=1 if self._token_count > 0 else 0
)