mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-20 05:18:16 +00:00

Compare commits

12 Commits

| Author | SHA1 | Date |
|---|---|---|
| | bfb578d506 | |
| | 5d3c34b3ea | |
| | 8ec2eb7d72 | |
| | 39bdc7e4d4 | |
| | 344fa9bbe5 | |
| | 0fd0b5c74f | |
| | 090d5128cb | |
| | dec255e87a | |
| | f75b07ce82 | |
| | cf2f21cbfb | |
| | c4a401b247 | |
| | b0d545992a | |
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 import os
 import shutil
 import subprocess
@@ -21,6 +23,9 @@ from crewai.tools.base_tool import Tool
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
 from crewai.utilities.converter import generate_model_description
+from crewai.utilities.logger import Logger
+from crewai.utilities.rpm_controller import RPMController
+from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.utilities.token_counter_callback import TokenCalcHandler
 from crewai.utilities.training_handler import CrewTrainingHandler
 
@@ -45,25 +50,114 @@ class Agent(BaseAgent):
     Each agent has a role, a goal, a backstory, and an optional language model (llm).
     The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
 
-    Attributes:
-        agent_executor: An instance of the CrewAgentExecutor class.
-        role: The role of the agent.
-        goal: The objective of the agent.
-        backstory: The backstory of the agent.
-        knowledge: The knowledge base of the agent.
-        config: Dict representation of agent configuration.
-        llm: The language model that will run the agent.
-        function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
-        max_iter: Maximum number of iterations for an agent to execute a task.
-        memory: Whether the agent should have memory or not.
-        max_rpm: Maximum number of requests per minute for the agent execution to be respected.
-        verbose: Whether the agent execution should be in verbose mode.
-        allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
-        tools: Tools at agents disposal
-        step_callback: Callback to be executed after each step of the agent execution.
-        knowledge_sources: Knowledge sources for the agent.
+    Args:
+        role (Optional[str]): The role of the agent
+        goal (Optional[str]): The objective of the agent
+        backstory (Optional[str]): The backstory of the agent
+        allow_delegation (bool): Whether the agent can delegate tasks
+        config (Optional[Dict[str, Any]]): Configuration for the agent
+        verbose (bool): Whether to enable verbose output
+        max_rpm (Optional[int]): Maximum requests per minute
+        tools (Optional[List[Any]]): Tools available to the agent
+        llm (Optional[Union[str, Any]]): Language model to use
+        function_calling_llm (Optional[Any]): Language model for tool calling
+        max_iter (Optional[int]): Maximum iterations for task execution
+        memory (bool): Whether the agent should have memory
+        step_callback (Optional[Any]): Callback after each execution step
+        knowledge_sources (Optional[List[BaseKnowledgeSource]]): Knowledge sources
     """
 
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "extra": "allow",
+    }
+
+    def __init__(
+        self,
+        role: Optional[str] = None,
+        goal: Optional[str] = None,
+        backstory: Optional[str] = None,
+        allow_delegation: bool = False,
+        config: Optional[Dict[str, Any]] = None,
+        verbose: bool = False,
+        max_rpm: Optional[int] = None,
+        tools: Optional[List[Any]] = None,
+        llm: Optional[Union[str, LLM, Any]] = None,
+        function_calling_llm: Optional[Any] = None,
+        max_iter: Optional[int] = None,
+        memory: bool = True,
+        step_callback: Optional[Any] = None,
+        knowledge_sources: Optional[List[BaseKnowledgeSource]] = None,
+        **kwargs
+    ) -> None:
+        """Initialize an Agent with the given parameters."""
+        # Process tools before passing to parent
+        processed_tools = []
+        if tools:
+            from crewai.tools import BaseTool
+            for tool in tools:
+                if isinstance(tool, BaseTool):
+                    processed_tools.append(tool)
+                elif callable(tool):
+                    # Convert function to BaseTool
+                    processed_tools.append(Tool.from_function(tool))
+                else:
+                    raise ValueError(f"Tool {tool} must be either a BaseTool instance or a callable")
+
+        # Process LLM before passing to parent
+        processed_llm = None
+        if isinstance(llm, str):
+            processed_llm = LLM(model=llm)
+        elif isinstance(llm, LLM):
+            processed_llm = llm
+        elif llm is not None and hasattr(llm, 'model') and hasattr(llm, 'temperature'):
+            # Handle ChatOpenAI and similar objects
+            model_name = getattr(llm, 'model', None)
+            if model_name is not None:
+                if not isinstance(model_name, str):
+                    model_name = str(model_name)
+                processed_llm = LLM(
+                    model=model_name,
+                    temperature=getattr(llm, 'temperature', None),
+                    api_key=getattr(llm, 'api_key', None),
+                    base_url=getattr(llm, 'base_url', None)
+                )
+        # If no valid LLM configuration found, leave as None for post_init_setup
+
+        # Initialize all fields in a dict
+        init_dict = {
+            "role": role,
+            "goal": goal,
+            "backstory": backstory,
+            "allow_delegation": allow_delegation,
+            "config": config,
+            "verbose": verbose,
+            "max_rpm": max_rpm,
+            "tools": processed_tools,
+            "max_iter": max_iter if max_iter is not None else 25,
+            "function_calling_llm": function_calling_llm,
+            "step_callback": step_callback,
+            "knowledge_sources": knowledge_sources,
+            **kwargs
+        }
+
+        # Initialize base model with all fields
+        super().__init__(**init_dict)
+
+        # Store original values for interpolation
+        self._original_role = role
+        self._original_goal = goal
+        self._original_backstory = backstory
+
+        # Set LLM after base initialization to ensure proper model handling
+        self.llm = processed_llm
+
+        # Initialize private attributes
+        self._logger = Logger(verbose=self.verbose)
+        if self.max_rpm:
+            self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
+        self._token_process = TokenProcess()
+
     _times_executed: int = PrivateAttr(default=0)
     max_execution_time: Optional[int] = Field(
         default=None,
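The hunk above replaces the class docstring and introduces an explicit keyword-based constructor. A minimal usage sketch of what that constructor is meant to accept on this branch (the `word_count` helper is hypothetical; the string-`llm` and bare-callable conversions are the ones performed in the `__init__` shown above, assuming `Tool.from_function` handles plain callables):

```python
from crewai import Agent

def word_count(text: str) -> int:
    """Count the words in a piece of text."""
    return len(text.split())

agent = Agent(
    role="Researcher",
    goal="Summarize articles",
    backstory="An experienced analyst",
    llm="gpt-4o-mini",   # a plain string becomes LLM(model="gpt-4o-mini")
    tools=[word_count],  # a bare callable is wrapped via Tool.from_function
    max_rpm=10,          # triggers creation of the RPMController private attribute
)
```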
@@ -138,21 +232,15 @@ class Agent(BaseAgent):
     @model_validator(mode="after")
     def post_init_setup(self):
         self._set_knowledge()
-        self.agent_ops_agent_name = self.role
+        self.agent_ops_agent_name = self.role or "agent"
         unaccepted_attributes = [
             "AWS_ACCESS_KEY_ID",
             "AWS_SECRET_ACCESS_KEY",
             "AWS_REGION_NAME",
         ]
 
-        # Handle different cases for self.llm
-        if isinstance(self.llm, str):
-            # If it's a string, create an LLM instance
-            self.llm = LLM(model=self.llm)
-        elif isinstance(self.llm, LLM):
-            # If it's already an LLM instance, keep it as is
-            pass
-        elif self.llm is None:
+        # Handle LLM initialization if not already done
+        if self.llm is None:
             # Determine the model name from environment variables or use default
             model_name = (
                 os.environ.get("OPENAI_MODEL_NAME")
@@ -190,9 +278,71 @@ class Agent(BaseAgent):
                 if key not in ["prompt", "key_name", "default"]:
                     # Only add default if the key is already set in os.environ
                     if key in os.environ:
-                        llm_params[key] = value
-
-            self.llm = LLM(**llm_params)
+                        try:
+                            # Create a new dictionary for properly typed parameters
+                            typed_params = {}
+
+                            # Convert and validate values based on parameter type
+                            if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
+                                if value is not None:
+                                    try:
+                                        typed_params[key] = float(value)
+                                    except (ValueError, TypeError):
+                                        pass
+                            elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']:
+                                if value is not None:
+                                    try:
+                                        typed_params[key] = int(value)
+                                    except (ValueError, TypeError):
+                                        pass
+                            elif key == 'logit_bias' and isinstance(value, str):
+                                try:
+                                    bias_dict = {}
+                                    for pair in value.split(','):
+                                        token_id, bias = pair.split(':')
+                                        bias_dict[int(token_id.strip())] = float(bias.strip())
+                                    typed_params[key] = bias_dict
+                                except (ValueError, AttributeError):
+                                    pass
+                            elif key == 'response_format' and isinstance(value, str):
+                                try:
+                                    import json
+                                    typed_params[key] = json.loads(value)
+                                except json.JSONDecodeError:
+                                    pass
+                            elif key == 'logprobs':
+                                if value is not None:
+                                    typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value)
+                            elif key == 'callbacks':
+                                typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value
+                            elif key == 'stop':
+                                typed_params[key] = [value] if isinstance(value, str) else value
+                            elif key in ['model', 'base_url', 'api_version', 'api_key']:
+                                typed_params[key] = value
+
+                            # Update llm_params with properly typed values
+                            if typed_params:
+                                llm_params.update(typed_params)
+                        except (ValueError, AttributeError, json.JSONDecodeError):
+                            continue
+
+            # Create LLM instance with properly typed parameters
+            valid_params = {
+                'model', 'timeout', 'temperature', 'top_p', 'n', 'stop',
+                'max_completion_tokens', 'max_tokens', 'presence_penalty',
+                'frequency_penalty', 'logit_bias', 'response_format',
+                'seed', 'logprobs', 'top_logprobs', 'base_url',
+                'api_version', 'api_key', 'callbacks'
+            }
+
+            # Filter out None values and invalid parameters
+            filtered_params = {}
+            for k, v in llm_params.items():
+                if k in valid_params and v is not None:
+                    filtered_params[k] = v
+
+            # Create LLM instance with properly typed parameters
+            self.llm = LLM(**filtered_params)
         else:
             # For any other type, attempt to extract relevant attributes
             llm_params = {
@@ -239,7 +389,7 @@ class Agent(BaseAgent):
     def _set_knowledge(self):
         try:
             if self.knowledge_sources:
-                knowledge_agent_name = f"{self.role.replace(' ', '_')}"
+                knowledge_agent_name = f"{(self.role or 'agent').replace(' ', '_')}"
                 if isinstance(self.knowledge_sources, list) and all(
                     isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                 ):
@@ -384,6 +534,32 @@ class Agent(BaseAgent):
             self.response_template.split("{{ .Response }}")[1].strip()
         )
 
+        # Ensure LLM is initialized with proper error handling
+        try:
+            if not self.llm:
+                self.llm = LLM(model="gpt-4")
+                if hasattr(self, '_logger'):
+                    self._logger.debug("Initialized default LLM with gpt-4 model")
+        except Exception as e:
+            if hasattr(self, '_logger'):
+                self._logger.error(f"Failed to initialize LLM: {str(e)}")
+            raise
+
+        # Create token callback with proper error handling
+        try:
+            token_callback = None
+            if hasattr(self, '_token_process'):
+                token_callback = TokenCalcHandler(self._token_process)
+        except Exception as e:
+            if hasattr(self, '_logger'):
+                self._logger.warning(f"Failed to create token callback: {str(e)}")
+            token_callback = None
+
+        # Initialize callbacks list
+        executor_callbacks = []
+        if token_callback:
+            executor_callbacks.append(token_callback)
+
         self.agent_executor = CrewAgentExecutor(
             llm=self.llm,
             task=task,
@@ -401,9 +577,9 @@ class Agent(BaseAgent):
             function_calling_llm=self.function_calling_llm,
             respect_context_window=self.respect_context_window,
             request_within_rpm_limit=(
-                self._rpm_controller.check_or_wait if self._rpm_controller else None
+                self._rpm_controller.check_or_wait if (hasattr(self, '_rpm_controller') and self._rpm_controller is not None) else None
             ),
-            callbacks=[TokenCalcHandler(self._token_process)],
+            callbacks=executor_callbacks,
         )
 
     def get_delegation_tools(self, agents: List[BaseAgent]):
@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
 from crewai.agents.cache.cache_handler import CacheHandler
 from crewai.agents.tools_handler import ToolsHandler
+from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
 from crewai.tools import BaseTool
 from crewai.tools.base_tool import Tool
 from crewai.utilities import I18N, Logger, RPMController
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
     formatting_errors: int = Field(
         default=0, description="Number of formatting errors."
     )
-    role: str = Field(description="Role of the agent")
-    goal: str = Field(description="Objective of the agent")
-    backstory: str = Field(description="Backstory of the agent")
+    role: Optional[str] = Field(default=None, description="Role of the agent")
+    goal: Optional[str] = Field(default=None, description="Objective of the agent")
+    backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
     config: Optional[Dict[str, Any]] = Field(
         description="Configuration for the agent", default=None, exclude=True
     )
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
     max_tokens: Optional[int] = Field(
         default=None, description="Maximum number of tokens for the agent's execution."
     )
+    function_calling_llm: Optional[Any] = Field(
+        default=None, description="Language model for function calling."
+    )
+    step_callback: Optional[Any] = Field(
+        default=None, description="Callback for execution steps."
+    )
+    knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
+        default=None, description="Knowledge sources for the agent."
+    )
+
+    model_config = {
+        "arbitrary_types_allowed": True,
+        "extra": "allow",  # Allow extra fields in constructor
+    }
 
     @model_validator(mode="before")
     @classmethod
     def process_model_config(cls, values):
+        """Process configuration values before model initialization."""
         return process_config(values, cls)
 
     @field_validator("tools")
     @classmethod
-    def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
+    def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
         """Validate and process the tools provided to the agent.
 
-        This method ensures that each tool is either an instance of BaseTool
-        or an object with 'name', 'func', and 'description' attributes. If the
-        tool meets these criteria, it is processed and added to the list of
-        tools. Otherwise, a ValueError is raised.
+        This method ensures that each tool is either an instance of BaseTool,
+        a function decorated with @tool, or an object with 'name', 'func',
+        and 'description' attributes. If the tool meets these criteria, it is
+        processed and added to the list of tools. Otherwise, a ValueError is raised.
         """
+        if not tools:
+            return []
+
         processed_tools = []
         for tool in tools:
             if isinstance(tool, BaseTool):
                 processed_tools.append(tool)
+            elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
+                # Handle @tool decorated functions
+                processed_tools.append(Tool.from_function(tool))
             elif (
                 hasattr(tool, "name")
                 and hasattr(tool, "func")
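The new `validate_tools` branch accepts callables that carry an `_is_tool` marker and routes them through `Tool.from_function`. An illustrative sketch only: the marker is set by hand here to show which branch of the validator fires on this change; the stock `@tool` decorator in `crewai.tools` returns a `BaseTool` instance and would already match the first branch:

```python
from crewai import Agent

def lookup(term: str) -> str:
    """Pretend to look up a term."""
    return f"definition of {term}"

lookup._is_tool = True  # marker checked by validate_tools before Tool.from_function

agent = Agent(role="Librarian", goal="Answer questions", backstory="Helpful", tools=[lookup])
```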
@@ -160,28 +182,51 @@ class BaseAgent(ABC, BaseModel):
             else:
                 raise ValueError(
                     f"Invalid tool type: {type(tool)}. "
-                    "Tool must be an instance of BaseTool or "
-                    "an object with 'name', 'func', and 'description' attributes."
+                    "Tool must be an instance of BaseTool, a @tool decorated function, "
+                    "or an object with 'name', 'func', and 'description' attributes."
                 )
         return processed_tools
 
     @model_validator(mode="after")
     def validate_and_set_attributes(self):
-        # Validate required fields
-        for field in ["role", "goal", "backstory"]:
-            if getattr(self, field) is None:
-                raise ValueError(
-                    f"{field} must be provided either directly or through config"
-                )
-
-        # Set private attributes
+        """Validate and set attributes for the agent.
+
+        This method ensures that attributes are properly set and initialized,
+        either from direct parameters or configuration.
+        """
+        # Store original values for interpolation
+        self._original_role = self.role
+        self._original_goal = self.goal
+        self._original_backstory = self.backstory
+
+        # Process config if provided
+        if self.config:
+            config_data = self.config
+            if isinstance(config_data, str):
+                import json
+                try:
+                    config_data = json.loads(config_data)
+                except json.JSONDecodeError:
+                    raise ValueError("Invalid JSON in config")
+
+            # Update fields from config if they're None
+            for field in ["role", "goal", "backstory"]:
+                if field in config_data and getattr(self, field) is None:
+                    setattr(self, field, config_data[field])
+
+        # Set default values for required fields if they're still None
+        self.role = self.role or "Assistant"
+        self.goal = self.goal or "Help the user accomplish their tasks"
+        self.backstory = self.backstory or "I am an AI assistant ready to help"
+
+        # Initialize tools handler if not set
+        if not hasattr(self, 'tools_handler') or self.tools_handler is None:
+            self.tools_handler = ToolsHandler()
+
+        # Initialize logger and rpm controller
         self._logger = Logger(verbose=self.verbose)
-        if self.max_rpm and not self._rpm_controller:
-            self._rpm_controller = RPMController(
-                max_rpm=self.max_rpm, logger=self._logger
-            )
-        if not self._token_process:
-            self._token_process = TokenProcess()
+        if self.max_rpm:
+            self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
 
         return self
 
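With the rewritten `validate_and_set_attributes`, missing `role`/`goal`/`backstory` no longer raise; they are filled from `config` and then from hard-coded defaults. A small sketch of the resulting behavior on this branch:

```python
from crewai import Agent

# role/goal may come from a config dict (or JSON string); anything still
# missing falls back to the defaults set in validate_and_set_attributes.
agent = Agent(config={"role": "Planner", "goal": "Break work into tasks"})
assert agent.role == "Planner"
assert agent.backstory == "I am an AI assistant ready to help"  # default fallback
```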
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
     @property
     def key(self):
         source = [
-            self._original_role or self.role,
-            self._original_goal or self.goal,
-            self._original_backstory or self.backstory,
+            str(self._original_role or self.role or ""),
+            str(self._original_goal or self.goal or ""),
+            str(self._original_backstory or self.backstory or ""),
         ]
         return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
 
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
             "tools_handler",
             "cache_handler",
             "llm",
+            "function_calling_llm",
         }
 
-        # Copy llm and clear callbacks
-        existing_llm = shallow_copy(self.llm)
+        # Copy LLMs and clear callbacks
+        existing_llm = shallow_copy(self.llm) if self.llm else None
+        existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
+
+        # Create base data
         copied_data = self.model_dump(exclude=exclude)
         copied_data = {k: v for k, v in copied_data.items() if v is not None}
-        copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
+
+        # Create new instance with copied data
+        copied_agent = type(self)(
+            **copied_data,
+            llm=existing_llm,
+            function_calling_llm=existing_function_calling_llm,
+            tools=self.tools
+        )
+
+        # Copy private attributes
+        copied_agent._original_role = self._original_role
+        copied_agent._original_goal = self._original_goal
+        copied_agent._original_backstory = self._original_backstory
+
         return copied_agent
 
     def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
         """Interpolate inputs into the agent description and backstory."""
         if self._original_role is None:
-            self._original_role = self.role
+            self._original_role = self.role or ""
         if self._original_goal is None:
-            self._original_goal = self.goal
+            self._original_goal = self.goal or ""
         if self._original_backstory is None:
-            self._original_backstory = self.backstory
+            self._original_backstory = self.backstory or ""
 
         if inputs:
-            self.role = self._original_role.format(**inputs)
-            self.goal = self._original_goal.format(**inputs)
-            self.backstory = self._original_backstory.format(**inputs)
+            self.role = self._original_role.format(**inputs) if self._original_role else None
+            self.goal = self._original_goal.format(**inputs) if self._original_goal else None
+            self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
 
     def set_cache_handler(self, cache_handler: CacheHandler) -> None:
         """Set the cache handler for the agent.
@@ -82,16 +82,17 @@ class CrewAgentExecutorMixin:
             )
             self.crew._long_term_memory.save(long_term_memory)
 
-            for entity in evaluation.entities:
-                entity_memory = EntityMemoryItem(
-                    name=entity.name,
-                    type=entity.type,
-                    description=entity.description,
-                    relationships="\n".join(
-                        [f"- {r}" for r in entity.relationships]
-                    ),
-                )
-                self.crew._entity_memory.save(entity_memory)
+            if hasattr(evaluation, 'entities') and evaluation.entities:
+                for entity in evaluation.entities:
+                    entity_memory = EntityMemoryItem(
+                        name=entity.name,
+                        type=entity.type,
+                        description=entity.description,
+                        relationships="\n".join(
+                            [f"- {r}" for r in entity.relationships]
+                        ),
+                    )
+                    self.crew._entity_memory.save(entity_memory)
         except AttributeError as e:
             print(f"Missing attributes for long term memory: {e}")
             pass
@@ -68,7 +68,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         self.tools_handler = tools_handler
         self.original_tools = original_tools
         self.step_callback = step_callback
-        self.use_stop_words = self.llm.supports_stop_words()
+        self.use_stop_words = self.llm.supports_stop_words() if self.llm else False
        self.tools_description = tools_description
         self.function_calling_llm = function_calling_llm
         self.respect_context_window = respect_context_window
@@ -147,7 +147,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                 # Directly append the result to the messages if the
                 # tool is "Add image to content" in case of multimodal
                 # agents
-                if formatted_answer.tool == self._i18n.tools("add_image")["name"]:
+                add_image_tool_name = self._i18n.tools("add_image")
+                if add_image_tool_name and formatted_answer.tool == add_image_tool_name:
                     self.messages.append(tool_result.result)
                     continue
 
@@ -214,13 +215,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         if self.agent.verbose or (
             hasattr(self, "crew") and getattr(self.crew, "verbose", False)
         ):
-            agent_role = self.agent.role.split("\n")[0]
+            agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
             self._printer.print(
                 content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
             )
-            self._printer.print(
-                content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
-            )
+            if self.task and self.task.description:
+                self._printer.print(
+                    content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
+                )
 
     def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
         if self.agent is None:
@@ -228,7 +230,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         if self.agent.verbose or (
             hasattr(self, "crew") and getattr(self.crew, "verbose", False)
         ):
-            agent_role = self.agent.role.split("\n")[0]
+            agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
             if isinstance(formatted_answer, AgentAction):
                 thought = re.sub(r"\n+", "\n", formatted_answer.thought)
                 formatted_json = json.dumps(
@@ -5,6 +5,7 @@ from pathlib import Path
 
 import click
 import requests
+from typing import Any
 
 from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
 
@@ -192,7 +193,7 @@ def download_data(response):
     data_chunks = []
     with click.progressbar(
         length=total_size, label="Downloading", show_pos=True
-    ) as progress_bar:
+    ) as progress_bar:  # type: Any
         for chunk in response.iter_content(block_size):
             if chunk:
                 data_chunks.append(chunk)
@@ -6,6 +6,8 @@ from concurrent.futures import Future
 from hashlib import md5
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 
+from crewai.tools.base_tool import BaseTool
+
 from pydantic import (
     UUID4,
     BaseModel,
|
|||||||
tools_for_task = task.tools or agent_to_use.tools or []
|
tools_for_task = task.tools or agent_to_use.tools or []
|
||||||
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
|
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
|
||||||
|
|
||||||
self._log_task_start(task, agent_to_use.role)
|
self._log_task_start(task, agent_to_use.role if agent_to_use and agent_to_use.role else "")
|
||||||
|
|
||||||
if isinstance(task, ConditionalTask):
|
if isinstance(task, ConditionalTask):
|
||||||
skipped_task_output = self._handle_conditional_task(
|
skipped_task_output = self._handle_conditional_task(
|
||||||
@@ -794,8 +796,8 @@ class Crew(BaseModel):
|
|||||||
return None
|
return None
|
||||||
|
|
||||||
def _prepare_tools(
|
def _prepare_tools(
|
||||||
self, agent: BaseAgent, task: Task, tools: List[Tool]
|
self, agent: BaseAgent, task: Task, tools: List[Union[Tool, BaseTool]]
|
||||||
) -> List[Tool]:
|
) -> List[Union[Tool, BaseTool]]:
|
||||||
# Add delegation tools if agent allows delegation
|
# Add delegation tools if agent allows delegation
|
||||||
if agent.allow_delegation:
|
if agent.allow_delegation:
|
||||||
if self.process == Process.hierarchical:
|
if self.process == Process.hierarchical:
|
||||||
@@ -824,8 +826,8 @@ class Crew(BaseModel):
|
|||||||
return task.agent
|
return task.agent
|
||||||
|
|
||||||
def _merge_tools(
|
def _merge_tools(
|
||||||
self, existing_tools: List[Tool], new_tools: List[Tool]
|
self, existing_tools: List[Union[Tool, BaseTool]], new_tools: List[Union[Tool, BaseTool]]
|
||||||
) -> List[Tool]:
|
) -> List[Union[Tool, BaseTool]]:
|
||||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||||
if not new_tools:
|
if not new_tools:
|
||||||
return existing_tools
|
return existing_tools
|
||||||
|
|||||||
@@ -1,5 +1,5 @@
 import json
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union
 
 from pydantic import BaseModel, Field
 
@@ -23,14 +23,25 @@ class CrewOutput(BaseModel):
     )
     token_usage: UsageMetrics = Field(description="Processed token summary", default={})
 
-    @property
-    def json(self) -> Optional[str]:
+    def json(
+        self,
+        *,
+        include: Union[set[str], None] = None,
+        exclude: Union[set[str], None] = None,
+        by_alias: bool = False,
+        exclude_unset: bool = False,
+        exclude_defaults: bool = False,
+        exclude_none: bool = False,
+        encoder: Optional[Callable[[Any], Any]] = None,
+        models_as_dict: bool = True,
+        **dumps_kwargs: Any,
+    ) -> str:
         if self.tasks_output[-1].output_format != OutputFormat.JSON:
             raise ValueError(
                 "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
             )
 
-        return json.dumps(self.json_dict)
+        return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)
 
     def to_dict(self) -> Dict[str, Any]:
         """Convert json_output and pydantic_output to a dictionary."""
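Because `json` changes from a property to a regular method here, callers move from attribute access to a call. A usage fragment rather than a complete script, assuming an existing `crew` whose final task sets `output_json`:

```python
result = crew.kickoff()
print(result.json(indent=2))  # extra **dumps_kwargs such as indent are forwarded to json.dumps
```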
@@ -106,7 +106,12 @@ class FlowPlot:
 
         # Add nodes to the network
         try:
-            add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
+            add_nodes_to_network(
+                net,
+                flow=self.flow,
+                pos=node_positions,
+                node_styles=self.node_styles
+            )
         except Exception as e:
             raise RuntimeError(f"Failed to add nodes to network: {str(e)}")
 
@@ -6,6 +6,8 @@ import warnings
 from contextlib import contextmanager
 from typing import Any, Dict, List, Optional, Union
 
+from pydantic import BaseModel, Field
+
 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
     import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
         sys.stderr = old_stderr
 
 
-class LLM:
+class LLM(BaseModel):
+    model: str = "gpt-4"  # Set default model
+    timeout: Optional[Union[float, int]] = None
+    temperature: Optional[float] = None
+    top_p: Optional[float] = None
+    n: Optional[int] = None
+    stop: Optional[Union[str, List[str]]] = None
+    max_completion_tokens: Optional[int] = None
+    max_tokens: Optional[int] = None
+    presence_penalty: Optional[float] = None
+    frequency_penalty: Optional[float] = None
+    logit_bias: Optional[Dict[int, float]] = None
+    response_format: Optional[Dict[str, Any]] = None
+    seed: Optional[int] = None
+    logprobs: Optional[bool] = None
+    top_logprobs: Optional[int] = None
+    base_url: Optional[str] = None
+    api_version: Optional[str] = None
+    api_key: Optional[str] = None
+    callbacks: Optional[List[Any]] = None
+    context_window_size: Optional[int] = None
+    kwargs: Dict[str, Any] = Field(default_factory=dict)
+    logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
+
     def __init__(
         self,
-        model: str,
+        model: Optional[Union[str, 'LLM']] = "gpt-4",
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
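The class now subclasses pydantic `BaseModel`, and its `model` argument may be a plain string or another `LLM` instance. A minimal sketch of the behavior the new constructor aims for on this branch, with the model name normalized back to a string and settings copied from the wrapped instance:

```python
from crewai import LLM

base = LLM(model="gpt-4o", temperature=0.2)

# Passing an existing LLM copies its configuration instead of failing.
clone = LLM(model=base)
assert clone.model == "gpt-4o"
assert clone.temperature == 0.2
```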
@@ -114,118 +139,427 @@ class LLM:
|
|||||||
base_url: Optional[str] = None,
|
base_url: Optional[str] = None,
|
||||||
api_version: Optional[str] = None,
|
api_version: Optional[str] = None,
|
||||||
api_key: Optional[str] = None,
|
api_key: Optional[str] = None,
|
||||||
callbacks: List[Any] = [],
|
callbacks: Optional[List[Any]] = None,
|
||||||
**kwargs,
|
context_window_size: Optional[int] = None,
|
||||||
):
|
**kwargs: Any,
|
||||||
self.model = model
|
) -> None:
|
||||||
self.timeout = timeout
|
# Initialize with default values
|
||||||
self.temperature = temperature
|
init_dict = {
|
||||||
self.top_p = top_p
|
"model": model if isinstance(model, str) else "gpt-4",
|
||||||
self.n = n
|
"timeout": timeout,
|
||||||
self.stop = stop
|
"temperature": temperature,
|
||||||
self.max_completion_tokens = max_completion_tokens
|
"top_p": top_p,
|
||||||
self.max_tokens = max_tokens
|
"n": n,
|
||||||
self.presence_penalty = presence_penalty
|
"stop": stop,
|
||||||
self.frequency_penalty = frequency_penalty
|
"max_completion_tokens": max_completion_tokens,
|
||||||
self.logit_bias = logit_bias
|
"max_tokens": max_tokens,
|
||||||
self.response_format = response_format
|
"presence_penalty": presence_penalty,
|
||||||
self.seed = seed
|
"frequency_penalty": frequency_penalty,
|
||||||
self.logprobs = logprobs
|
"logit_bias": logit_bias,
|
||||||
self.top_logprobs = top_logprobs
|
"response_format": response_format,
|
||||||
self.base_url = base_url
|
"seed": seed,
|
||||||
self.api_version = api_version
|
"logprobs": logprobs,
|
||||||
self.api_key = api_key
|
"top_logprobs": top_logprobs,
|
||||||
self.callbacks = callbacks
|
"base_url": base_url,
|
||||||
self.context_window_size = 0
|
"api_version": api_version,
|
||||||
self.kwargs = kwargs
|
"api_key": api_key,
|
||||||
|
"callbacks": callbacks,
|
||||||
|
"context_window_size": context_window_size,
|
||||||
|
"kwargs": kwargs,
|
||||||
|
}
|
||||||
|
super().__init__(**init_dict)
|
||||||
|
|
||||||
|
# Initialize model with default value
|
||||||
|
self.model = "gpt-4" # Default fallback
|
||||||
|
|
||||||
|
# Extract and validate model name
|
||||||
|
if isinstance(model, LLM):
|
||||||
|
# Extract and validate model name from LLM instance
|
||||||
|
if hasattr(model, 'model'):
|
||||||
|
if isinstance(model.model, str):
|
||||||
|
self.model = model.model
|
||||||
|
else:
|
||||||
|
# Try to extract string model name from nested LLM
|
||||||
|
if isinstance(model.model, LLM):
|
||||||
|
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
# Extract and validate model name for non-LLM instances
|
||||||
|
if not isinstance(model, str):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
|
||||||
|
if model is not None:
|
||||||
|
if hasattr(model, 'model_name'):
|
||||||
|
model_name = getattr(model, 'model_name', None)
|
||||||
|
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||||
|
elif hasattr(model, 'model'):
|
||||||
|
model_attr = getattr(model, 'model', None)
|
||||||
|
self.model = str(model_attr) if model_attr is not None else "gpt-4"
|
||||||
|
elif hasattr(model, '_model_name'):
|
||||||
|
model_name = getattr(model, '_model_name', None)
|
||||||
|
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4" # Default fallback
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4" # Default fallback for None
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Model is None, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
self.model = str(model) # Ensure it's a string
|
||||||
|
|
||||||
|
# If model is an LLM instance, copy its configuration
|
||||||
|
if isinstance(model, LLM):
|
||||||
|
# Extract and validate model name first
|
||||||
|
if hasattr(model, 'model'):
|
||||||
|
if isinstance(model.model, str):
|
||||||
|
self.model = model.model
|
||||||
|
else:
|
||||||
|
# Try to extract string model name from nested LLM
|
||||||
|
if isinstance(model.model, LLM):
|
||||||
|
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
|
||||||
|
|
||||||
|
# Copy other configuration
|
||||||
|
self.timeout = model.timeout
|
||||||
|
self.temperature = model.temperature
|
||||||
|
self.top_p = model.top_p
|
||||||
|
self.n = model.n
|
||||||
|
self.stop = model.stop
|
||||||
|
self.max_completion_tokens = model.max_completion_tokens
|
||||||
|
self.max_tokens = model.max_tokens
|
||||||
|
self.presence_penalty = model.presence_penalty
|
||||||
|
self.frequency_penalty = model.frequency_penalty
|
||||||
|
self.logit_bias = model.logit_bias
|
||||||
|
self.response_format = model.response_format
|
||||||
|
self.seed = model.seed
|
||||||
|
self.logprobs = model.logprobs
|
||||||
|
self.top_logprobs = model.top_logprobs
|
||||||
|
self.base_url = model.base_url
|
||||||
|
self.api_version = model.api_version
|
||||||
|
self.api_key = model.api_key
|
||||||
|
self.callbacks = model.callbacks
|
||||||
|
self.context_window_size = model.context_window_size
|
||||||
|
self.kwargs = model.kwargs
|
||||||
|
|
||||||
|
# Final validation of model name
|
||||||
|
if not isinstance(self.model, str):
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
# Extract and validate model name for non-LLM instances
|
||||||
|
if not isinstance(model, str):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
|
||||||
|
if model is not None:
|
||||||
|
if hasattr(model, 'model_name'):
|
||||||
|
model_name = getattr(model, 'model_name', None)
|
||||||
|
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||||
|
elif hasattr(model, 'model'):
|
||||||
|
model_attr = getattr(model, 'model', None)
|
||||||
|
self.model = str(model_attr) if model_attr is not None else "gpt-4"
|
||||||
|
elif hasattr(model, '_model_name'):
|
||||||
|
model_name = getattr(model, '_model_name', None)
|
||||||
|
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4" # Default fallback
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4" # Default fallback for None
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Model is None, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
self.model = str(model) # Ensure it's a string
|
||||||
|
|
||||||
|
# Final validation
|
||||||
|
if not isinstance(self.model, str):
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
|
||||||
|
|
||||||
|
self.timeout = timeout
|
||||||
|
self.temperature = temperature
|
||||||
|
self.top_p = top_p
|
||||||
|
self.n = n
|
||||||
|
self.stop = stop
|
||||||
|
self.max_completion_tokens = max_completion_tokens
|
||||||
|
self.max_tokens = max_tokens
|
||||||
|
self.presence_penalty = presence_penalty
|
||||||
|
self.frequency_penalty = frequency_penalty
|
||||||
|
self.logit_bias = logit_bias
|
||||||
|
self.response_format = response_format
|
||||||
|
self.seed = seed
|
||||||
|
self.logprobs = logprobs
|
||||||
|
self.top_logprobs = top_logprobs
|
||||||
|
self.base_url = base_url
|
||||||
|
self.api_version = api_version
|
||||||
|
self.api_key = api_key
|
||||||
|
self.callbacks = callbacks
|
||||||
|
self.context_window_size = 0
|
||||||
|
self.kwargs = kwargs
|
||||||
|
|
||||||
|
# Ensure model is a string after initialization
|
||||||
|
if not isinstance(self.model, str):
|
||||||
|
self.model = "gpt-4"
|
||||||
|
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
|
||||||
|
|
||||||
litellm.drop_params = True
|
litellm.drop_params = True
|
||||||
|
|
||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
self.set_env_callbacks()
|
self.set_env_callbacks()
|
||||||
|
|
||||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
def call(
|
||||||
|
self,
|
||||||
|
messages: List[Dict[str, str]],
|
||||||
|
callbacks: Optional[List[Any]] = None
|
||||||
|
) -> str:
|
||||||
with suppress_warnings():
|
with suppress_warnings():
|
||||||
if callbacks and len(callbacks) > 0:
|
if callbacks and len(callbacks) > 0:
|
||||||
self.set_callbacks(callbacks)
|
self.set_callbacks(callbacks)
|
||||||
|
|
||||||
|
# Store original model to restore later
|
||||||
|
original_model = self.model
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
# Ensure model is a string before making the call
|
||||||
|
if not isinstance(self.model, str):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
|
||||||
|
if isinstance(self.model, LLM):
|
||||||
|
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
|
||||||
|
elif hasattr(self.model, 'model_name'):
|
||||||
|
self.model = str(self.model.model_name)
|
||||||
|
elif hasattr(self.model, 'model'):
|
||||||
|
if isinstance(self.model.model, str):
|
||||||
|
self.model = str(self.model.model)
|
||||||
|
elif hasattr(self.model.model, 'model_name'):
|
||||||
|
self.model = str(self.model.model.model_name)
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
|
||||||
|
else:
|
||||||
|
self.model = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Could not extract model name, using default: gpt-4")
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
|
||||||
|
|
||||||
|
# Create base params with validated model name
|
||||||
|
# Extract model name string
|
||||||
|
model_name = None
|
||||||
|
if isinstance(self.model, str):
|
||||||
|
model_name = self.model
|
||||||
|
elif hasattr(self.model, 'model_name'):
|
||||||
|
model_name = str(self.model.model_name)
|
||||||
|
elif hasattr(self.model, 'model'):
|
||||||
|
if isinstance(self.model.model, str):
|
||||||
|
model_name = str(self.model.model)
|
||||||
|
elif hasattr(self.model.model, 'model_name'):
|
||||||
|
model_name = str(self.model.model.model_name)
|
||||||
|
|
||||||
|
if not model_name:
|
||||||
|
model_name = "gpt-4"
|
||||||
|
if self.logger:
|
||||||
|
self.logger.warning("Could not extract model name, using default: gpt-4")
|
||||||
|
|
||||||
params = {
|
params = {
|
||||||
"model": self.model,
|
"model": model_name,
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"timeout": self.timeout,
|
"stream": False,
|
||||||
"temperature": self.temperature,
|
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
|
||||||
"top_p": self.top_p,
|
|
||||||
"n": self.n,
|
|
||||||
"stop": self.stop,
|
|
||||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
|
||||||
"presence_penalty": self.presence_penalty,
|
|
||||||
"frequency_penalty": self.frequency_penalty,
|
|
||||||
"logit_bias": self.logit_bias,
|
|
||||||
"response_format": self.response_format,
|
|
||||||
"seed": self.seed,
|
|
||||||
"logprobs": self.logprobs,
|
|
||||||
"top_logprobs": self.top_logprobs,
|
|
||||||
"api_base": self.base_url,
|
"api_base": self.base_url,
|
||||||
"api_version": self.api_version,
|
"api_version": self.api_version,
|
||||||
"api_key": self.api_key,
|
|
||||||
"stream": False,
|
|
||||||
**self.kwargs,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if self.logger:
|
||||||
|
self.logger.debug(f"Using model parameters: {params}")
|
||||||
|
|
||||||
|
# Add API configuration if available
|
||||||
|
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
|
||||||
|
if api_key:
|
||||||
|
params["api_key"] = api_key
|
||||||
|
|
||||||
|
# Try to get supported parameters for the model
|
||||||
|
try:
|
||||||
|
supported_params = get_supported_openai_params(self.model)
|
||||||
|
optional_params = {}
|
||||||
|
|
||||||
|
if supported_params:
|
||||||
|
param_mapping = {
|
||||||
|
"timeout": self.timeout,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"n": self.n,
|
||||||
|
"stop": self.stop,
|
||||||
|
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||||
|
"presence_penalty": self.presence_penalty,
|
||||||
|
"frequency_penalty": self.frequency_penalty,
|
||||||
|
"logit_bias": self.logit_bias,
|
||||||
|
"response_format": self.response_format,
|
||||||
|
"seed": self.seed,
|
||||||
|
"logprobs": self.logprobs,
|
||||||
|
"top_logprobs": self.top_logprobs
|
||||||
|
}
|
||||||
|
|
||||||
|
# Only add parameters that are supported and not None
|
||||||
|
optional_params = {
|
||||||
|
k: v for k, v in param_mapping.items()
|
||||||
|
if k in supported_params and v is not None
|
||||||
|
}
|
||||||
|
if "logprobs" in supported_params and self.logprobs is not None:
|
||||||
|
optional_params["logprobs"] = self.logprobs
|
||||||
|
if "top_logprobs" in supported_params and self.top_logprobs is not None:
|
||||||
|
optional_params["top_logprobs"] = self.top_logprobs
|
||||||
|
except Exception as e:
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
|
||||||
|
# If we can't get supported params, just add non-None parameters
|
||||||
|
param_mapping = {
|
||||||
|
"timeout": self.timeout,
|
||||||
|
"temperature": self.temperature,
|
||||||
|
"top_p": self.top_p,
|
||||||
|
"n": self.n,
|
||||||
|
"stop": self.stop,
|
||||||
|
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||||
|
"presence_penalty": self.presence_penalty,
|
||||||
|
"frequency_penalty": self.frequency_penalty,
|
||||||
|
"logit_bias": self.logit_bias,
|
||||||
|
"response_format": self.response_format,
|
||||||
|
"seed": self.seed,
|
||||||
|
"logprobs": self.logprobs,
|
||||||
|
"top_logprobs": self.top_logprobs
|
||||||
|
}
|
||||||
|
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
|
||||||
|
|
||||||
|
# Update params with optional parameters
|
||||||
|
params.update(optional_params)
|
||||||
|
|
||||||
|
# Add API endpoint configuration if available
|
||||||
|
if self.base_url:
|
||||||
|
params["api_base"] = self.base_url
|
||||||
|
if self.api_version:
|
||||||
|
params["api_version"] = self.api_version
|
||||||
|
|
||||||
|
# Final validation of model parameter
|
||||||
|
if not isinstance(params["model"], str):
|
||||||
|
if self.logger:
|
||||||
|
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
|
||||||
|
params["model"] = "gpt-4"
|
||||||
|
|
||||||
|
# Update params with non-None optional parameters
|
||||||
|
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||||
|
|
||||||
|
# Add any additional kwargs
|
||||||
|
if self.kwargs:
|
||||||
|
params.update(self.kwargs)
|
||||||
|
|
||||||
# Remove None values to avoid passing unnecessary parameters
|
# Remove None values to avoid passing unnecessary parameters
|
||||||
params = {k: v for k, v in params.items() if v is not None}
|
params = {k: v for k, v in params.items() if v is not None}
|
||||||
|
|
||||||
response = litellm.completion(**params)
|
response = litellm.completion(**params)
|
||||||
return response["choices"][0]["message"]["content"]
|
content = response["choices"][0]["message"]["content"]
|
||||||
|
|
||||||
|
# Extract usage metrics
|
||||||
|
usage = response.get("usage", {})
|
||||||
|
if callbacks:
|
||||||
|
for callback in callbacks:
|
||||||
|
if hasattr(callback, "update_token_usage"):
|
||||||
|
callback.update_token_usage(usage)
|
||||||
|
|
||||||
|
return content
|
||||||
        except Exception as e:
            if not LLMContextLengthExceededException(
                str(e)
            )._is_context_limit_error(str(e)):
                logging.error(f"LiteLLM call failed: {str(e)}")
                raise  # Re-raise the exception after logging
+       finally:
+           # Always restore the original model object
+           self.model = original_model
    def supports_function_calling(self) -> bool:
+       """Check if the LLM supports function calling.
+
+       Returns:
+           bool: True if the model supports function calling, False otherwise
+       """
        try:
            params = get_supported_openai_params(model=self.model)
            return "response_format" in params
        except Exception as e:
-           logging.error(f"Failed to get supported params: {str(e)}")
+           if self.logger:
+               self.logger.error(f"Failed to get supported params: {str(e)}")
            return False
    def supports_stop_words(self) -> bool:
+       """Check if the LLM supports stop words.
+
+       Returns False if the LLM is not properly initialized."""
+       if not hasattr(self, 'model') or self.model is None:
+           return False
        try:
            params = get_supported_openai_params(model=self.model)
            return "stop" in params
        except Exception as e:
-           logging.error(f"Failed to get supported params: {str(e)}")
+           if self.logger:
+               self.logger.error(f"Failed to get supported params: {str(e)}")
            return False
    def get_context_window_size(self) -> int:
-       # Only using 75% of the context window size to avoid cutting the message in the middle
-       if self.context_window_size != 0:
-           return self.context_window_size
-
-       self.context_window_size = int(
-           DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
-       )
-       for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
-           if self.model.startswith(key):
-               self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
+       """Get the context window size for the current model.
+
+       Returns:
+           int: The context window size in tokens
+       """
+       # Only using 75% of the context window size to avoid cutting the message in the middle
+       if self.context_window_size is not None and self.context_window_size != 0:
+           return int(self.context_window_size)
+
+       window_size = DEFAULT_CONTEXT_WINDOW_SIZE
+       if isinstance(self.model, str):
+           for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
+               if self.model.startswith(key):
+                   window_size = value
+                   break
+
+       self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
        return self.context_window_size
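Note: a quick sketch of the sizing rule above, using illustrative numbers. The real DEFAULT_CONTEXT_WINDOW_SIZE, CONTEXT_WINDOW_USAGE_RATIO, and LLM_CONTEXT_WINDOW_SIZES live in crewAI's constants; the values below are assumptions for the example only.

    # Hypothetical stand-ins for crewAI's constants, for illustration only.
    DEFAULT_CONTEXT_WINDOW_SIZE = 8192
    CONTEXT_WINDOW_USAGE_RATIO = 0.75  # keep 25% headroom so messages are not cut mid-way
    LLM_CONTEXT_WINDOW_SIZES = {"gpt-4o": 128000, "gpt-4": 8192}  # more specific prefix first

    def usable_window(model: str) -> int:
        # Mirror the lookup-by-prefix logic in the method above.
        window = DEFAULT_CONTEXT_WINDOW_SIZE
        for prefix, size in LLM_CONTEXT_WINDOW_SIZES.items():
            if model.startswith(prefix):
                window = size
                break
        return int(window * CONTEXT_WINDOW_USAGE_RATIO)

    print(usable_window("gpt-4o"))  # 96000, i.e. 75% of 128000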
-   def set_callbacks(self, callbacks: List[Any]):
-       callback_types = [type(callback) for callback in callbacks]
-       for callback in litellm.success_callback[:]:
-           if type(callback) in callback_types:
-               litellm.success_callback.remove(callback)
-
-       for callback in litellm._async_success_callback[:]:
-           if type(callback) in callback_types:
-               litellm._async_success_callback.remove(callback)
-
-       litellm.callbacks = callbacks
+   def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
+       """Set callbacks for the LLM.
+
+       Args:
+           callbacks: Optional list of callback functions. If None, no callbacks will be set.
+       """
+       if callbacks is not None:
+           callback_types = [type(callback) for callback in callbacks]
+           for callback in litellm.success_callback[:]:
+               if type(callback) in callback_types:
+                   litellm.success_callback.remove(callback)
+
+           for callback in litellm._async_success_callback[:]:
+               if type(callback) in callback_types:
+                   litellm._async_success_callback.remove(callback)
+
+           litellm.callbacks = callbacks

    def set_env_callbacks(self):
        """
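Note: set_callbacks is normally wired up internally by the agent and crew, but the new Optional default can be exercised directly; a minimal sketch, assuming the public LLM constructor and an illustrative model name:

    from crewai import LLM

    llm = LLM(model="gpt-4o-mini")  # model name is illustrative
    llm.set_callbacks([])    # an explicit (even empty) list still goes through litellm registration
    llm.set_callbacks(None)  # with the new signature this is a no-op instead of raising a TypeError
    llm.set_callbacks()      # callbacks defaults to None, same as above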
@@ -269,7 +269,9 @@ class Task(BaseModel):
    @model_validator(mode="after")
    def check_tools(self):
        """Check if the tools are set."""
-       if not self.tools and self.agent and self.agent.tools:
+       if self.agent and self.agent.tools:
+           if self.tools is None:
+               self.tools = []
            self.tools.extend(self.agent.tools)
        return self
@@ -348,7 +350,8 @@ class Task(BaseModel):
        self.prompt_context = context
        tools = tools or self.tools or []

-       self.processed_by_agents.add(agent.role)
+       if agent and agent.role:
+           self.processed_by_agents.add(agent.role)

        result = agent.execute_task(
            task=self,
@@ -1,5 +1,5 @@
import json
-from typing import Any, Dict, Optional
+from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field, model_validator
@@ -34,8 +34,19 @@ class TaskOutput(BaseModel):
        self.summary = f"{excerpt}..."
        return self

-   @property
-   def json(self) -> Optional[str]:
+   def json(
+       self,
+       *,
+       include: Union[set[str], None] = None,
+       exclude: Union[set[str], None] = None,
+       by_alias: bool = False,
+       exclude_unset: bool = False,
+       exclude_defaults: bool = False,
+       exclude_none: bool = False,
+       encoder: Optional[Callable[[Any], Any]] = None,
+       models_as_dict: bool = True,
+       **dumps_kwargs: Any,
+   ) -> str:
        if self.output_format != OutputFormat.JSON:
            raise ValueError(
                """
@@ -45,7 +56,7 @@ class TaskOutput(BaseModel):
            """
        )

-       return json.dumps(self.json_dict)
+       return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Convert json_output and pydantic_output to a dictionary."""
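Note: with the widened signature above, callers can forward a custom encoder and json.dumps options through to the serialization call; a hedged sketch, where the task_output object and the contents of its json_dict are assumed for illustration:

    from datetime import datetime

    def encode(value):
        # Fallback encoder for values json.dumps cannot handle natively.
        if isinstance(value, datetime):
            return value.isoformat()
        raise TypeError(f"Unserializable value: {value!r}")

    # task_output is assumed to be a TaskOutput whose output_format is JSON.
    pretty = task_output.json(encoder=encode, indent=2, sort_keys=True)
    print(pretty)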
@@ -19,13 +19,13 @@ class BaseAgentTool(BaseTool):
        default_factory=I18N, description="Internationalization settings"
    )

-   def sanitize_agent_name(self, name: str) -> str:
+   def sanitize_agent_name(self, name: Optional[str]) -> str:
        """
        Sanitize agent role name by normalizing whitespace and setting to lowercase.
        Converts all whitespace (including newlines) to single spaces and removes quotes.

        Args:
-           name (str): The agent role name to sanitize
+           name (Optional[str]): The agent role name to sanitize

        Returns:
            str: The sanitized agent role name, with whitespace normalized,
@@ -142,7 +142,12 @@ class CrewStructuredTool:

        # Create model
        schema_name = f"{name.title()}Schema"
-       return create_model(schema_name, **fields)
+       return create_model(
+           schema_name,
+           __base__=BaseModel,
+           __config__=None,
+           **{k: v for k, v in fields.items()}
+       )

    def _validate_function_signature(self) -> None:
        """Validate that the function signature matches the args schema."""
@@ -170,7 +175,7 @@ class CrewStructuredTool:
            f"not found in args_schema"
        )

-   def _parse_args(self, raw_args: Union[str, dict]) -> dict:
+   def _parse_args(self, raw_args: Union[str, dict[str, Any]]) -> dict[str, Any]:
        """Parse and validate the input arguments against the schema.

        Args:
@@ -178,6 +183,9 @@ class CrewStructuredTool:

        Returns:
            The validated arguments as a dictionary
+
+       Raises:
+           ValueError: If the arguments cannot be parsed or fail validation
        """
        if isinstance(raw_args, str):
            try:
@@ -195,8 +203,8 @@ class CrewStructuredTool:

    async def ainvoke(
        self,
-       input: Union[str, dict],
-       config: Optional[dict] = None,
+       input: Union[str, dict[str, Any]],
+       config: Optional[dict[str, Any]] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously invoke the tool.
@@ -229,7 +237,10 @@ class CrewStructuredTool:
        return self.invoke(input_dict)

    def invoke(
-       self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
+       self,
+       input: Union[str, dict[str, Any]],
+       config: Optional[dict[str, Any]] = None,
+       **kwargs: Any
    ) -> Any:
        """Main method for tool execution."""
        parsed_args = self._parse_args(input)
@@ -10,8 +10,24 @@ class Logger(BaseModel):
    _printer: Printer = PrivateAttr(default_factory=Printer)

    def log(self, level, message, color="bold_yellow"):
-       if self.verbose:
+       if self.verbose or level.upper() in ["WARNING", "ERROR"]:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            self._printer.print(
                f"\n[{timestamp}][{level.upper()}]: {message}", color=color
            )

+   def debug(self, message: str) -> None:
+       """Log a debug message if verbose is enabled."""
+       self.log("debug", message, color="bold_blue")
+
+   def info(self, message: str) -> None:
+       """Log an info message if verbose is enabled."""
+       self.log("info", message, color="bold_green")
+
+   def warning(self, message: str) -> None:
+       """Log a warning message."""
+       self.log("warning", message, color="bold_yellow")
+
+   def error(self, message: str) -> None:
+       """Log an error message."""
+       self.log("error", message, color="bold_red")
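Note: a small usage sketch of the new level helpers; it assumes Logger is constructed with a verbose flag, as the field referenced in log() suggests:

    logger = Logger(verbose=False)

    logger.debug("tool selected")        # suppressed: verbose is False
    logger.info("task started")          # suppressed: verbose is False
    logger.warning("rate limit close")   # printed: warnings now always log
    logger.error("LLM call failed")      # printed: errors now always log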
@@ -63,16 +63,32 @@ class Prompts(BaseModel):
            for component in components
            if component != "task"
        ]
-       system = system_template.replace("{{ .System }}", "".join(prompt_parts))
-       prompt = prompt_template.replace(
-           "{{ .Prompt }}", "".join(self.i18n.slice("task"))
-       )
-       response = response_template.split("{{ .Response }}")[0]
-       prompt = f"{system}\n{prompt}\n{response}"
+       system = ""
+       if system_template:
+           system = system_template.replace("{{ .System }}", "".join(prompt_parts))
+
+       prompt_text = ""
+       if prompt_template:
+           prompt_text = prompt_template.replace(
+               "{{ .Prompt }}", "".join(self.i18n.slice("task"))
+           )
+
+       response = ""
+       if response_template:
+           response = response_template.split("{{ .Response }}")[0]
+
+       parts = [p for p in [system, prompt_text, response] if p]
+       prompt = "\n".join(parts) if parts else ""
+
+       # Get agent attributes with default values
+       goal = str(getattr(self.agent, 'goal', '') or '')
+       role = str(getattr(self.agent, 'role', '') or '')
+       backstory = str(getattr(self.agent, 'backstory', '') or '')
+
+       # Replace placeholders with agent attributes
        prompt = (
-           prompt.replace("{goal}", self.agent.goal)
-           .replace("{role}", self.agent.role)
-           .replace("{backstory}", self.agent.backstory)
+           prompt.replace("{goal}", goal)
+           .replace("{role}", role)
+           .replace("{backstory}", backstory)
        )
        return prompt
src/crewai/utilities/token_process.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""Token processing utility for tracking and managing token usage."""

from __future__ import annotations

from typing import Any, Dict, List, Optional, Union

from crewai.types.usage_metrics import UsageMetrics


class TokenProcess:
    """Handles token processing and tracking for agents."""

    def __init__(self):
        """Initialize the token processor."""
        self._total_tokens = 0
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._cached_prompt_tokens = 0
        self._successful_requests = 0

    def sum_prompt_tokens(self, count: int) -> None:
        """Add to prompt token count.

        Args:
            count (int): Number of prompt tokens to add
        """
        self._prompt_tokens += count
        self._total_tokens += count

    def sum_completion_tokens(self, count: int) -> None:
        """Add to completion token count.

        Args:
            count (int): Number of completion tokens to add
        """
        self._completion_tokens += count
        self._total_tokens += count

    def sum_cached_prompt_tokens(self, count: int) -> None:
        """Add to cached prompt token count.

        Args:
            count (int): Number of cached prompt tokens to add
        """
        self._cached_prompt_tokens += count

    def sum_successful_requests(self, count: int) -> None:
        """Add to successful requests count.

        Args:
            count (int): Number of successful requests to add
        """
        self._successful_requests += count

    def reset(self) -> None:
        """Reset all token counts to zero."""
        self._total_tokens = 0
        self._prompt_tokens = 0
        self._completion_tokens = 0
        self._cached_prompt_tokens = 0
        self._successful_requests = 0

    def get_summary(self) -> UsageMetrics:
        """Get a summary of token usage.

        Returns:
            UsageMetrics: Object containing token usage metrics
        """
        return UsageMetrics(
            total_tokens=self._total_tokens,
            prompt_tokens=self._prompt_tokens,
            cached_prompt_tokens=self._cached_prompt_tokens,
            completion_tokens=self._completion_tokens,
            successful_requests=self._successful_requests
        )
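Note: a short usage sketch of the new utility; the token counts are made up for illustration, and the import path is the module added above:

    from crewai.utilities.token_process import TokenProcess

    tracker = TokenProcess()

    # Record one completed LLM call (illustrative numbers).
    tracker.sum_prompt_tokens(1200)
    tracker.sum_completion_tokens(350)
    tracker.sum_cached_prompt_tokens(400)
    tracker.sum_successful_requests(1)

    summary = tracker.get_summary()
    print(summary.total_tokens)  # 1550: prompt + completion (cached tokens are tracked separately)

    tracker.reset()  # back to all zeros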
@@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools():
    crew.kickoff()

    # Verify task tools override agent tools
-   assert len(task.tools) == 1  # AnotherTestTool
-   assert any(isinstance(tool, AnotherTestTool) for tool in task.tools)
-   assert not any(isinstance(tool, TestTool) for tool in task.tools)
+   tools = task.tools or []
+   assert len(tools) == 1  # AnotherTestTool
+   assert any(isinstance(tool, AnotherTestTool) for tool in tools)
+   assert not any(isinstance(tool, TestTool) for tool in tools)

    # Verify agent tools remain unchanged
    assert len(new_researcher.tools) == 1
tests/pytest.ini (new file, 6 lines)
@@ -0,0 +1,6 @@
[pytest]
markers =
    vcr: Mark a test as using VCR.py for recording/replaying HTTP interactions

[vcr]
record_mode = none
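Note: with the marker registered above, tests can be tagged without pytest emitting unknown-marker warnings; a hedged sketch of a test module, assuming a VCR pytest plugin (such as pytest-recording) supplies the vcr marker behavior:

    import pytest

    @pytest.mark.vcr()
    def test_llm_call_replays_recorded_http():
        # With record_mode = none, VCR only replays existing cassettes and
        # fails on unrecorded requests, so no live HTTP happens in CI.
        ...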