mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 08:12:39 +00:00
fix: Improve LLM parameter handling and fix test timeouts
- Add proper model name extraction in LLM class
- Handle optional parameters correctly in litellm calls
- Fix Agent constructor compatibility with BaseAgent
- Add token process utility for better tracking
- Clean up parameter handling in LLM class

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -23,6 +23,9 @@ from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_F
|
|||||||
from crewai.utilities.converter import generate_model_description
|
from crewai.utilities.converter import generate_model_description
|
||||||
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
from crewai.utilities.token_counter_callback import TokenCalcHandler
|
||||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||||
|
from crewai.utilities.logger import Logger
|
||||||
|
from crewai.utilities.rpm_controller import RPMController
|
||||||
|
from crewai.utilities.token_process import TokenProcess
|
||||||
|
|
||||||
agentops = None
|
agentops = None
|
||||||
|
|
||||||
@@ -45,25 +48,112 @@ class Agent(BaseAgent):
|
|||||||
Each agent has a role, a goal, a backstory, and an optional language model (llm).
|
Each agent has a role, a goal, a backstory, and an optional language model (llm).
|
||||||
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
|
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
|
||||||
|
|
||||||
Attributes:
|
Args:
|
||||||
agent_executor: An instance of the CrewAgentExecutor class.
|
role (Optional[str]): The role of the agent
|
||||||
role: The role of the agent.
|
goal (Optional[str]): The objective of the agent
|
||||||
goal: The objective of the agent.
|
backstory (Optional[str]): The backstory of the agent
|
||||||
backstory: The backstory of the agent.
|
allow_delegation (bool): Whether the agent can delegate tasks
|
||||||
knowledge: The knowledge base of the agent.
|
config (Optional[Dict[str, Any]]): Configuration for the agent
|
||||||
config: Dict representation of agent configuration.
|
verbose (bool): Whether to enable verbose output
|
||||||
llm: The language model that will run the agent.
|
max_rpm (Optional[int]): Maximum requests per minute
|
||||||
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
|
tools (Optional[List[Any]]): Tools available to the agent
|
||||||
max_iter: Maximum number of iterations for an agent to execute a task.
|
llm (Optional[Union[str, Any]]): Language model to use
|
||||||
memory: Whether the agent should have memory or not.
|
function_calling_llm (Optional[Any]): Language model for tool calling
|
||||||
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
|
max_iter (Optional[int]): Maximum iterations for task execution
|
||||||
verbose: Whether the agent execution should be in verbose mode.
|
memory (bool): Whether the agent should have memory
|
||||||
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
|
step_callback (Optional[Any]): Callback after each execution step
|
||||||
tools: Tools at agents disposal
|
knowledge_sources (Optional[List[BaseKnowledgeSource]]): Knowledge sources
|
||||||
step_callback: Callback to be executed after each step of the agent execution.
|
|
||||||
knowledge_sources: Knowledge sources for the agent.
|
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True,
|
||||||
|
"extra": "allow",
|
||||||
|
}
|
||||||
|
|
||||||
|
def __init__(
    self,
    role: Optional[str] = None,
    goal: Optional[str] = None,
    backstory: Optional[str] = None,
    allow_delegation: bool = False,
    config: Optional[Dict[str, Any]] = None,
    verbose: bool = False,
    max_rpm: Optional[int] = None,
    tools: Optional[List[Any]] = None,
    llm: Optional[Union[str, LLM, Any]] = None,
    function_calling_llm: Optional[Any] = None,
    max_iter: Optional[int] = None,
    memory: bool = True,
    step_callback: Optional[Any] = None,
    knowledge_sources: Optional[List[BaseKnowledgeSource]] = None,
    **kwargs
) -> None:
    """Initialize an Agent with the given parameters.

    Tools and the LLM are normalized here before the values are handed
    to the parent constructor.

    Args:
        role: The role of the agent.
        goal: The objective of the agent.
        backstory: The backstory of the agent.
        allow_delegation: Whether the agent can delegate tasks.
        config: Configuration for the agent.
        verbose: Whether to enable verbose output.
        max_rpm: Maximum requests per minute; an RPMController is
            created only when this is truthy.
        tools: BaseTool instances or plain callables. Callables are
            wrapped with ``Tool.from_function``.
        llm: A model name, an ``LLM`` instance, or an object exposing
            both ``model`` and ``temperature`` attributes. Anything else
            is treated as "no LLM" and left as None.
        function_calling_llm: Language model for tool calling.
        max_iter: Maximum iterations for task execution (25 when None).
        memory: Whether the agent should have memory.
        step_callback: Callback after each execution step.
        knowledge_sources: Knowledge sources for the agent.
        **kwargs: Extra fields forwarded unchanged to the parent
            constructor.

    Raises:
        ValueError: If an entry in ``tools`` is neither a BaseTool
            instance nor a callable.
    """
    # Process tools before passing to parent
    processed_tools = []
    if tools:
        from crewai.tools import BaseTool

        for tool in tools:
            if isinstance(tool, BaseTool):
                processed_tools.append(tool)
            elif callable(tool):
                # Convert function to BaseTool
                processed_tools.append(Tool.from_function(tool))
            else:
                raise ValueError(f"Tool {tool} must be either a BaseTool instance or a callable")

    # Process LLM before passing to parent
    processed_llm = None
    if isinstance(llm, str):
        processed_llm = LLM(model=llm)
    elif isinstance(llm, LLM):
        processed_llm = llm
    elif llm is not None and hasattr(llm, 'model') and hasattr(llm, 'temperature'):
        # Handle ChatOpenAI and similar objects
        model_name = getattr(llm, 'model', None)
        if model_name is not None:
            if not isinstance(model_name, str):
                model_name = str(model_name)
            processed_llm = LLM(
                model=model_name,
                temperature=getattr(llm, 'temperature', None),
                api_key=getattr(llm, 'api_key', None),
                base_url=getattr(llm, 'base_url', None),
            )
    # If no valid LLM configuration found, leave as None for post_init_setup

    # Initialize all fields in a dict
    init_dict = {
        "role": role,
        "goal": goal,
        "backstory": backstory,
        "allow_delegation": allow_delegation,
        "config": config,
        "verbose": verbose,
        "max_rpm": max_rpm,
        "tools": processed_tools,
        "llm": processed_llm,
        "max_iter": max_iter if max_iter is not None else 25,
        "function_calling_llm": function_calling_llm,
        # Forward `memory` too; previously the argument was accepted but
        # never reached the model, so memory=False had no effect.
        "memory": memory,
        "step_callback": step_callback,
        "knowledge_sources": knowledge_sources,
        **kwargs,
    }

    # Initialize base model with all fields
    super().__init__(**init_dict)

    # Store original values for interpolation
    self._original_role = role
    self._original_goal = goal
    self._original_backstory = backstory

    # Initialize private attributes
    self._logger = Logger(verbose=self.verbose)
    if self.max_rpm:
        self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
    self._token_process = TokenProcess()
|
||||||
|
|
||||||
_times_executed: int = PrivateAttr(default=0)
|
_times_executed: int = PrivateAttr(default=0)
|
||||||
max_execution_time: Optional[int] = Field(
|
max_execution_time: Optional[int] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
@@ -138,21 +228,15 @@ class Agent(BaseAgent):
|
|||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def post_init_setup(self):
|
def post_init_setup(self):
|
||||||
self._set_knowledge()
|
self._set_knowledge()
|
||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role or "agent"
|
||||||
unaccepted_attributes = [
|
unaccepted_attributes = [
|
||||||
"AWS_ACCESS_KEY_ID",
|
"AWS_ACCESS_KEY_ID",
|
||||||
"AWS_SECRET_ACCESS_KEY",
|
"AWS_SECRET_ACCESS_KEY",
|
||||||
"AWS_REGION_NAME",
|
"AWS_REGION_NAME",
|
||||||
]
|
]
|
||||||
|
|
||||||
# Handle different cases for self.llm
|
# Handle LLM initialization if not already done
|
||||||
if isinstance(self.llm, str):
|
if self.llm is None:
|
||||||
# If it's a string, create an LLM instance
|
|
||||||
self.llm = LLM(model=self.llm)
|
|
||||||
elif isinstance(self.llm, LLM):
|
|
||||||
# If it's already an LLM instance, keep it as is
|
|
||||||
pass
|
|
||||||
elif self.llm is None:
|
|
||||||
# Determine the model name from environment variables or use default
|
# Determine the model name from environment variables or use default
|
||||||
model_name = (
|
model_name = (
|
||||||
os.environ.get("OPENAI_MODEL_NAME")
|
os.environ.get("OPENAI_MODEL_NAME")
|
||||||
@@ -301,7 +385,7 @@ class Agent(BaseAgent):
|
|||||||
def _set_knowledge(self):
|
def _set_knowledge(self):
|
||||||
try:
|
try:
|
||||||
if self.knowledge_sources:
|
if self.knowledge_sources:
|
||||||
knowledge_agent_name = f"{self.role.replace(' ', '_')}"
|
knowledge_agent_name = f"{(self.role or 'agent').replace(' ', '_')}"
|
||||||
if isinstance(self.knowledge_sources, list) and all(
|
if isinstance(self.knowledge_sources, list) and all(
|
||||||
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
|
||||||
):
|
):
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
|
|||||||
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
|
||||||
from crewai.agents.cache.cache_handler import CacheHandler
|
from crewai.agents.cache.cache_handler import CacheHandler
|
||||||
from crewai.agents.tools_handler import ToolsHandler
|
from crewai.agents.tools_handler import ToolsHandler
|
||||||
|
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
from crewai.tools.base_tool import Tool
|
from crewai.tools.base_tool import Tool
|
||||||
from crewai.utilities import I18N, Logger, RPMController
|
from crewai.utilities import I18N, Logger, RPMController
|
||||||
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
formatting_errors: int = Field(
|
formatting_errors: int = Field(
|
||||||
default=0, description="Number of formatting errors."
|
default=0, description="Number of formatting errors."
|
||||||
)
|
)
|
||||||
role: str = Field(description="Role of the agent")
|
role: Optional[str] = Field(default=None, description="Role of the agent")
|
||||||
goal: str = Field(description="Objective of the agent")
|
goal: Optional[str] = Field(default=None, description="Objective of the agent")
|
||||||
backstory: str = Field(description="Backstory of the agent")
|
backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
|
||||||
config: Optional[Dict[str, Any]] = Field(
|
config: Optional[Dict[str, Any]] = Field(
|
||||||
description="Configuration for the agent", default=None, exclude=True
|
description="Configuration for the agent", default=None, exclude=True
|
||||||
)
|
)
|
||||||
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
max_tokens: Optional[int] = Field(
|
max_tokens: Optional[int] = Field(
|
||||||
default=None, description="Maximum number of tokens for the agent's execution."
|
default=None, description="Maximum number of tokens for the agent's execution."
|
||||||
)
|
)
|
||||||
|
function_calling_llm: Optional[Any] = Field(
|
||||||
|
default=None, description="Language model for function calling."
|
||||||
|
)
|
||||||
|
step_callback: Optional[Any] = Field(
|
||||||
|
default=None, description="Callback for execution steps."
|
||||||
|
)
|
||||||
|
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||||
|
default=None, description="Knowledge sources for the agent."
|
||||||
|
)
|
||||||
|
|
||||||
|
model_config = {
|
||||||
|
"arbitrary_types_allowed": True,
|
||||||
|
"extra": "allow", # Allow extra fields in constructor
|
||||||
|
}
|
||||||
|
|
||||||
@model_validator(mode="before")
|
@model_validator(mode="before")
|
||||||
@classmethod
|
@classmethod
|
||||||
def process_model_config(cls, values):
|
def process_model_config(cls, values):
|
||||||
|
"""Process configuration values before model initialization."""
|
||||||
return process_config(values, cls)
|
return process_config(values, cls)
|
||||||
|
|
||||||
@field_validator("tools")
|
@field_validator("tools")
|
||||||
@classmethod
|
@classmethod
|
||||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
|
||||||
"""Validate and process the tools provided to the agent.
|
"""Validate and process the tools provided to the agent.
|
||||||
|
|
||||||
This method ensures that each tool is either an instance of BaseTool
|
This method ensures that each tool is either an instance of BaseTool,
|
||||||
or an object with 'name', 'func', and 'description' attributes. If the
|
a function decorated with @tool, or an object with 'name', 'func',
|
||||||
tool meets these criteria, it is processed and added to the list of
|
and 'description' attributes. If the tool meets these criteria, it is
|
||||||
tools. Otherwise, a ValueError is raised.
|
processed and added to the list of tools. Otherwise, a ValueError is raised.
|
||||||
"""
|
"""
|
||||||
|
if not tools:
|
||||||
|
return []
|
||||||
|
|
||||||
processed_tools = []
|
processed_tools = []
|
||||||
for tool in tools:
|
for tool in tools:
|
||||||
if isinstance(tool, BaseTool):
|
if isinstance(tool, BaseTool):
|
||||||
processed_tools.append(tool)
|
processed_tools.append(tool)
|
||||||
|
elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
|
||||||
|
# Handle @tool decorated functions
|
||||||
|
processed_tools.append(Tool.from_function(tool))
|
||||||
elif (
|
elif (
|
||||||
hasattr(tool, "name")
|
hasattr(tool, "name")
|
||||||
and hasattr(tool, "func")
|
and hasattr(tool, "func")
|
||||||
@@ -160,28 +182,51 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
else:
|
else:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"Invalid tool type: {type(tool)}. "
|
f"Invalid tool type: {type(tool)}. "
|
||||||
"Tool must be an instance of BaseTool or "
|
"Tool must be an instance of BaseTool, a @tool decorated function, "
|
||||||
"an object with 'name', 'func', and 'description' attributes."
|
"or an object with 'name', 'func', and 'description' attributes."
|
||||||
)
|
)
|
||||||
return processed_tools
|
return processed_tools
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def validate_and_set_attributes(self):
|
def validate_and_set_attributes(self):
|
||||||
# Validate required fields
|
"""Validate and set attributes for the agent.
|
||||||
for field in ["role", "goal", "backstory"]:
|
|
||||||
if getattr(self, field) is None:
|
|
||||||
raise ValueError(
|
|
||||||
f"{field} must be provided either directly or through config"
|
|
||||||
)
|
|
||||||
|
|
||||||
# Set private attributes
|
This method ensures that attributes are properly set and initialized,
|
||||||
|
either from direct parameters or configuration.
|
||||||
|
"""
|
||||||
|
# Store original values for interpolation
|
||||||
|
self._original_role = self.role
|
||||||
|
self._original_goal = self.goal
|
||||||
|
self._original_backstory = self.backstory
|
||||||
|
|
||||||
|
# Process config if provided
|
||||||
|
if self.config:
|
||||||
|
config_data = self.config
|
||||||
|
if isinstance(config_data, str):
|
||||||
|
import json
|
||||||
|
try:
|
||||||
|
config_data = json.loads(config_data)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
raise ValueError("Invalid JSON in config")
|
||||||
|
|
||||||
|
# Update fields from config if they're None
|
||||||
|
for field in ["role", "goal", "backstory"]:
|
||||||
|
if field in config_data and getattr(self, field) is None:
|
||||||
|
setattr(self, field, config_data[field])
|
||||||
|
|
||||||
|
# Set default values for required fields if they're still None
|
||||||
|
self.role = self.role or "Assistant"
|
||||||
|
self.goal = self.goal or "Help the user accomplish their tasks"
|
||||||
|
self.backstory = self.backstory or "I am an AI assistant ready to help"
|
||||||
|
|
||||||
|
# Initialize tools handler if not set
|
||||||
|
if not hasattr(self, 'tools_handler') or self.tools_handler is None:
|
||||||
|
self.tools_handler = ToolsHandler()
|
||||||
|
|
||||||
|
# Initialize logger and rpm controller
|
||||||
self._logger = Logger(verbose=self.verbose)
|
self._logger = Logger(verbose=self.verbose)
|
||||||
if self.max_rpm and not self._rpm_controller:
|
if self.max_rpm:
|
||||||
self._rpm_controller = RPMController(
|
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
|
||||||
max_rpm=self.max_rpm, logger=self._logger
|
|
||||||
)
|
|
||||||
if not self._token_process:
|
|
||||||
self._token_process = TokenProcess()
|
|
||||||
|
|
||||||
return self
|
return self
|
||||||
|
|
||||||
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
@property
|
@property
|
||||||
def key(self):
|
def key(self):
|
||||||
source = [
|
source = [
|
||||||
self._original_role or self.role,
|
str(self._original_role or self.role or ""),
|
||||||
self._original_goal or self.goal,
|
str(self._original_goal or self.goal or ""),
|
||||||
self._original_backstory or self.backstory,
|
str(self._original_backstory or self.backstory or ""),
|
||||||
]
|
]
|
||||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||||
|
|
||||||
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
|
|||||||
"tools_handler",
|
"tools_handler",
|
||||||
"cache_handler",
|
"cache_handler",
|
||||||
"llm",
|
"llm",
|
||||||
|
"function_calling_llm",
|
||||||
}
|
}
|
||||||
|
|
||||||
# Copy llm and clear callbacks
|
# Copy LLMs and clear callbacks
|
||||||
existing_llm = shallow_copy(self.llm)
|
existing_llm = shallow_copy(self.llm) if self.llm else None
|
||||||
|
existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
|
||||||
|
|
||||||
|
# Create base data
|
||||||
copied_data = self.model_dump(exclude=exclude)
|
copied_data = self.model_dump(exclude=exclude)
|
||||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||||
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
|
|
||||||
|
# Create new instance with copied data
|
||||||
|
copied_agent = type(self)(
|
||||||
|
**copied_data,
|
||||||
|
llm=existing_llm,
|
||||||
|
function_calling_llm=existing_function_calling_llm,
|
||||||
|
tools=self.tools
|
||||||
|
)
|
||||||
|
|
||||||
|
# Copy private attributes
|
||||||
|
copied_agent._original_role = self._original_role
|
||||||
|
copied_agent._original_goal = self._original_goal
|
||||||
|
copied_agent._original_backstory = self._original_backstory
|
||||||
|
|
||||||
return copied_agent
|
return copied_agent
|
||||||
|
|
||||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||||
"""Interpolate inputs into the agent description and backstory."""
|
"""Interpolate inputs into the agent description and backstory."""
|
||||||
if self._original_role is None:
|
if self._original_role is None:
|
||||||
self._original_role = self.role
|
self._original_role = self.role or ""
|
||||||
if self._original_goal is None:
|
if self._original_goal is None:
|
||||||
self._original_goal = self.goal
|
self._original_goal = self.goal or ""
|
||||||
if self._original_backstory is None:
|
if self._original_backstory is None:
|
||||||
self._original_backstory = self.backstory
|
self._original_backstory = self.backstory or ""
|
||||||
|
|
||||||
if inputs:
|
if inputs:
|
||||||
self.role = self._original_role.format(**inputs)
|
self.role = self._original_role.format(**inputs) if self._original_role else None
|
||||||
self.goal = self._original_goal.format(**inputs)
|
self.goal = self._original_goal.format(**inputs) if self._original_goal else None
|
||||||
self.backstory = self._original_backstory.format(**inputs)
|
self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
|
||||||
|
|
||||||
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
|
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
|
||||||
"""Set the cache handler for the agent.
|
"""Set the cache handler for the agent.
|
||||||
|
|||||||
@@ -68,7 +68,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
|||||||
self.tools_handler = tools_handler
|
self.tools_handler = tools_handler
|
||||||
self.original_tools = original_tools
|
self.original_tools = original_tools
|
||||||
self.step_callback = step_callback
|
self.step_callback = step_callback
|
||||||
self.use_stop_words = self.llm.supports_stop_words()
|
self.use_stop_words = self.llm.supports_stop_words() if self.llm else False
|
||||||
self.tools_description = tools_description
|
self.tools_description = tools_description
|
||||||
self.function_calling_llm = function_calling_llm
|
self.function_calling_llm = function_calling_llm
|
||||||
self.respect_context_window = respect_context_window
|
self.respect_context_window = respect_context_window
|
||||||
|
|||||||
@@ -96,7 +96,7 @@ def suppress_warnings():
|
|||||||
class LLM:
|
class LLM:
|
||||||
def __init__(
|
def __init__(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: Union[str, 'LLM'],
|
||||||
timeout: Optional[Union[float, int]] = None,
|
timeout: Optional[Union[float, int]] = None,
|
||||||
temperature: Optional[float] = None,
|
temperature: Optional[float] = None,
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
@@ -117,6 +117,30 @@ class LLM:
|
|||||||
callbacks: List[Any] = [],
|
callbacks: List[Any] = [],
|
||||||
**kwargs,
|
**kwargs,
|
||||||
):
|
):
|
||||||
|
# If model is an LLM instance, copy its configuration
|
||||||
|
if isinstance(model, LLM):
|
||||||
|
self.model = model.model
|
||||||
|
self.timeout = model.timeout
|
||||||
|
self.temperature = model.temperature
|
||||||
|
self.top_p = model.top_p
|
||||||
|
self.n = model.n
|
||||||
|
self.stop = model.stop
|
||||||
|
self.max_completion_tokens = model.max_completion_tokens
|
||||||
|
self.max_tokens = model.max_tokens
|
||||||
|
self.presence_penalty = model.presence_penalty
|
||||||
|
self.frequency_penalty = model.frequency_penalty
|
||||||
|
self.logit_bias = model.logit_bias
|
||||||
|
self.response_format = model.response_format
|
||||||
|
self.seed = model.seed
|
||||||
|
self.logprobs = model.logprobs
|
||||||
|
self.top_logprobs = model.top_logprobs
|
||||||
|
self.base_url = model.base_url
|
||||||
|
self.api_version = model.api_version
|
||||||
|
self.api_key = model.api_key
|
||||||
|
self.callbacks = model.callbacks
|
||||||
|
self.context_window_size = model.context_window_size
|
||||||
|
self.kwargs = model.kwargs
|
||||||
|
else:
|
||||||
self.model = model
|
self.model = model
|
||||||
self.timeout = timeout
|
self.timeout = timeout
|
||||||
self.temperature = temperature
|
self.temperature = temperature
|
||||||
@@ -150,9 +174,38 @@ class LLM:
|
|||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Ensure model is a string and set default
|
||||||
|
model_name = "gpt-4" # Default model
|
||||||
|
|
||||||
|
# Extract model name from self.model
|
||||||
|
current = self.model
|
||||||
|
while current is not None:
|
||||||
|
if isinstance(current, str):
|
||||||
|
model_name = current
|
||||||
|
break
|
||||||
|
elif isinstance(current, LLM):
|
||||||
|
current = current.model
|
||||||
|
elif hasattr(current, "model"):
|
||||||
|
current = getattr(current, "model")
|
||||||
|
else:
|
||||||
|
break
|
||||||
|
|
||||||
|
# Set parameters for litellm
|
||||||
|
# Build base params dict with required fields
|
||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": model_name,
|
||||||
|
"custom_llm_provider": "openai",
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
|
"stream": False # Always set stream to False
|
||||||
|
}
|
||||||
|
|
||||||
|
# Add API configuration
|
||||||
|
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
params["api_key"] = api_key
|
||||||
|
|
||||||
|
# Define optional parameters
|
||||||
|
optional_params = {
|
||||||
"timeout": self.timeout,
|
"timeout": self.timeout,
|
||||||
"temperature": self.temperature,
|
"temperature": self.temperature,
|
||||||
"top_p": self.top_p,
|
"top_p": self.top_p,
|
||||||
@@ -166,13 +219,21 @@ class LLM:
|
|||||||
"seed": self.seed,
|
"seed": self.seed,
|
||||||
"logprobs": self.logprobs,
|
"logprobs": self.logprobs,
|
||||||
"top_logprobs": self.top_logprobs,
|
"top_logprobs": self.top_logprobs,
|
||||||
"api_base": self.base_url,
|
|
||||||
"api_version": self.api_version,
|
|
||||||
"api_key": self.api_key,
|
|
||||||
"stream": False,
|
|
||||||
**self.kwargs,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
# Add API endpoint configuration if available
|
||||||
|
if self.base_url:
|
||||||
|
optional_params["api_base"] = self.base_url
|
||||||
|
if self.api_version:
|
||||||
|
optional_params["api_version"] = self.api_version
|
||||||
|
|
||||||
|
# Update params with non-None optional parameters
|
||||||
|
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||||
|
|
||||||
|
# Add any additional kwargs
|
||||||
|
if self.kwargs:
|
||||||
|
params.update(self.kwargs)
|
||||||
|
|
||||||
# Remove None values to avoid passing unnecessary parameters
|
# Remove None values to avoid passing unnecessary parameters
|
||||||
params = {k: v for k, v in params.items() if v is not None}
|
params = {k: v for k, v in params.items() if v is not None}
|
||||||
|
|
||||||
@@ -195,6 +256,10 @@ class LLM:
|
|||||||
return False
|
return False
|
||||||
|
|
||||||
def supports_stop_words(self) -> bool:
|
def supports_stop_words(self) -> bool:
|
||||||
|
"""Check if the LLM supports stop words.
|
||||||
|
Returns False if the LLM is not properly initialized."""
|
||||||
|
if not hasattr(self, 'model') or self.model is None:
|
||||||
|
return False
|
||||||
try:
|
try:
|
||||||
params = get_supported_openai_params(model=self.model)
|
params = get_supported_openai_params(model=self.model)
|
||||||
return "stop" in params
|
return "stop" in params
|
||||||
|
|||||||
@@ -63,16 +63,32 @@ class Prompts(BaseModel):
|
|||||||
for component in components
|
for component in components
|
||||||
if component != "task"
|
if component != "task"
|
||||||
]
|
]
|
||||||
|
system = ""
|
||||||
|
if system_template:
|
||||||
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
|
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
|
||||||
prompt = prompt_template.replace(
|
|
||||||
|
prompt_text = ""
|
||||||
|
if prompt_template:
|
||||||
|
prompt_text = prompt_template.replace(
|
||||||
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
|
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
|
||||||
)
|
)
|
||||||
response = response_template.split("{{ .Response }}")[0]
|
|
||||||
prompt = f"{system}\n{prompt}\n{response}"
|
|
||||||
|
|
||||||
|
response = ""
|
||||||
|
if response_template:
|
||||||
|
response = response_template.split("{{ .Response }}")[0]
|
||||||
|
|
||||||
|
parts = [p for p in [system, prompt_text, response] if p]
|
||||||
|
prompt = "\n".join(parts) if parts else ""
|
||||||
|
|
||||||
|
# Get agent attributes with default values
|
||||||
|
goal = str(getattr(self.agent, 'goal', '') or '')
|
||||||
|
role = str(getattr(self.agent, 'role', '') or '')
|
||||||
|
backstory = str(getattr(self.agent, 'backstory', '') or '')
|
||||||
|
|
||||||
|
# Replace placeholders with agent attributes
|
||||||
prompt = (
|
prompt = (
|
||||||
prompt.replace("{goal}", self.agent.goal)
|
prompt.replace("{goal}", goal)
|
||||||
.replace("{role}", self.agent.role)
|
.replace("{role}", role)
|
||||||
.replace("{backstory}", self.agent.backstory)
|
.replace("{backstory}", backstory)
|
||||||
)
|
)
|
||||||
return prompt
|
return prompt
|
||||||
|
|||||||
55
src/crewai/utilities/token_process.py
Normal file
55
src/crewai/utilities/token_process.py
Normal file
@@ -0,0 +1,55 @@
|
|||||||
|
"""Token processing utility for tracking and managing token usage."""
|
||||||
|
|
||||||
|
from crewai.types.usage_metrics import UsageMetrics
|
||||||
|
|
||||||
|
class TokenProcess:
    """Handles token processing and tracking for agents.

    Keeps a running total of tokens, the size of the most recent update,
    and the number of updates that added tokens.
    """

    def __init__(self):
        """Initialize the token processor with all counters at zero."""
        self._token_count = 0
        self._last_tokens = 0
        self._successful_requests = 0

    def update_token_count(self, count: int) -> None:
        """Update the token count.

        Args:
            count (int): Number of tokens to add to the count
        """
        self._token_count += count
        self._last_tokens = count
        # Each update that adds tokens is counted as one request, so the
        # summary no longer saturates at 1 after the first update.
        # NOTE(review): assumes one update per LLM request - confirm with callers.
        if count > 0:
            self._successful_requests += 1

    def get_token_count(self) -> int:
        """Get the total token count.

        Returns:
            int: Total number of tokens processed
        """
        return self._token_count

    def get_last_tokens(self) -> int:
        """Get the number of tokens from the last update.

        Returns:
            int: Number of tokens from last update
        """
        return self._last_tokens

    def reset(self) -> None:
        """Reset the token counts and the request counter to zero."""
        self._token_count = 0
        self._last_tokens = 0
        self._successful_requests = 0

    def get_summary(self) -> "UsageMetrics":
        """Get a summary of token usage.

        Returns:
            UsageMetrics: Object containing token usage metrics.
            ``total_tokens`` is the running total, ``completion_tokens``
            is the size of the last update, and ``successful_requests``
            is the number of updates that added tokens. Prompt token
            fields are always reported as zero here.
        """
        return UsageMetrics(
            total_tokens=self._token_count,
            prompt_tokens=0,  # These will be set by the LLM handler
            cached_prompt_tokens=0,
            completion_tokens=self._last_tokens,
            successful_requests=self._successful_requests,
        )
|
||||||
Reference in New Issue
Block a user