Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 12:28:30 +00:00)

Compare commits: 12 commits, bugfix/flo ... pr-1833
| Author | SHA1 | Date |
|---|---|---|
|  | bfb578d506 |  |
|  | 5d3c34b3ea |  |
|  | 8ec2eb7d72 |  |
|  | 39bdc7e4d4 |  |
|  | 344fa9bbe5 |  |
|  | 0fd0b5c74f |  |
|  | 090d5128cb |  |
|  | dec255e87a |  |
|  | f75b07ce82 |  |
|  | cf2f21cbfb |  |
|  | c4a401b247 |  |
|  | b0d545992a |  |
@@ -1,3 +1,5 @@
from __future__ import annotations

import os
import shutil
import subprocess
@@ -21,6 +23,9 @@ from crewai.tools.base_tool import Tool
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.logger import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler

@@ -45,24 +50,113 @@ class Agent(BaseAgent):
    Each agent has a role, a goal, a backstory, and an optional language model (llm).
    The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.

    Attributes:
        agent_executor: An instance of the CrewAgentExecutor class.
        role: The role of the agent.
        goal: The objective of the agent.
        backstory: The backstory of the agent.
        knowledge: The knowledge base of the agent.
        config: Dict representation of agent configuration.
        llm: The language model that will run the agent.
        function_calling_llm: The language model that handles tool calling for this agent; it overrides the crew's function_calling_llm.
        max_iter: Maximum number of iterations for an agent to execute a task.
        memory: Whether the agent should have memory or not.
        max_rpm: Maximum number of requests per minute for the agent execution to be respected.
        verbose: Whether the agent execution should be in verbose mode.
        allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
        tools: Tools at the agent's disposal.
        step_callback: Callback to be executed after each step of the agent execution.
        knowledge_sources: Knowledge sources for the agent.

    Args:
        role (Optional[str]): The role of the agent
        goal (Optional[str]): The objective of the agent
        backstory (Optional[str]): The backstory of the agent
        allow_delegation (bool): Whether the agent can delegate tasks
        config (Optional[Dict[str, Any]]): Configuration for the agent
        verbose (bool): Whether to enable verbose output
        max_rpm (Optional[int]): Maximum requests per minute
        tools (Optional[List[Any]]): Tools available to the agent
        llm (Optional[Union[str, Any]]): Language model to use
        function_calling_llm (Optional[Any]): Language model for tool calling
        max_iter (Optional[int]): Maximum iterations for task execution
        memory (bool): Whether the agent should have memory
        step_callback (Optional[Any]): Callback after each execution step
        knowledge_sources (Optional[List[BaseKnowledgeSource]]): Knowledge sources
    """

    model_config = {
        "arbitrary_types_allowed": True,
        "extra": "allow",
    }

    def __init__(
        self,
        role: Optional[str] = None,
        goal: Optional[str] = None,
        backstory: Optional[str] = None,
        allow_delegation: bool = False,
        config: Optional[Dict[str, Any]] = None,
        verbose: bool = False,
        max_rpm: Optional[int] = None,
        tools: Optional[List[Any]] = None,
        llm: Optional[Union[str, LLM, Any]] = None,
        function_calling_llm: Optional[Any] = None,
        max_iter: Optional[int] = None,
        memory: bool = True,
        step_callback: Optional[Any] = None,
        knowledge_sources: Optional[List[BaseKnowledgeSource]] = None,
        **kwargs
    ) -> None:
        """Initialize an Agent with the given parameters."""
        # Process tools before passing to parent
        processed_tools = []
        if tools:
            from crewai.tools import BaseTool
            for tool in tools:
                if isinstance(tool, BaseTool):
                    processed_tools.append(tool)
                elif callable(tool):
                    # Convert function to BaseTool
                    processed_tools.append(Tool.from_function(tool))
                else:
                    raise ValueError(f"Tool {tool} must be either a BaseTool instance or a callable")

        # Process LLM before passing to parent
        processed_llm = None
        if isinstance(llm, str):
            processed_llm = LLM(model=llm)
        elif isinstance(llm, LLM):
            processed_llm = llm
        elif llm is not None and hasattr(llm, 'model') and hasattr(llm, 'temperature'):
            # Handle ChatOpenAI and similar objects
            model_name = getattr(llm, 'model', None)
            if model_name is not None:
                if not isinstance(model_name, str):
                    model_name = str(model_name)
                processed_llm = LLM(
                    model=model_name,
                    temperature=getattr(llm, 'temperature', None),
                    api_key=getattr(llm, 'api_key', None),
                    base_url=getattr(llm, 'base_url', None)
                )
        # If no valid LLM configuration found, leave as None for post_init_setup

        # Initialize all fields in a dict
        init_dict = {
            "role": role,
            "goal": goal,
            "backstory": backstory,
            "allow_delegation": allow_delegation,
            "config": config,
            "verbose": verbose,
            "max_rpm": max_rpm,
            "tools": processed_tools,
            "max_iter": max_iter if max_iter is not None else 25,
            "function_calling_llm": function_calling_llm,
            "step_callback": step_callback,
            "knowledge_sources": knowledge_sources,
            **kwargs
        }

        # Initialize base model with all fields
        super().__init__(**init_dict)

        # Store original values for interpolation
        self._original_role = role
        self._original_goal = goal
        self._original_backstory = backstory

        # Set LLM after base initialization to ensure proper model handling
        self.llm = processed_llm

        # Initialize private attributes
        self._logger = Logger(verbose=self.verbose)
        if self.max_rpm:
            self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
        self._token_process = TokenProcess()

    _times_executed: int = PrivateAttr(default=0)
    max_execution_time: Optional[int] = Field(
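For orientation, a minimal sketch of how this constructor is meant to be used. The role/goal/backstory strings, the tool function, and the model name below are illustrative placeholders, not taken from this diff:

    from crewai import Agent

    def lookup_docs(query: str) -> str:
        """Toy helper passed as a tool; plain callables are wrapped via Tool.from_function."""
        return f"Results for {query}"

    # role/goal/backstory are now Optional and may instead come from `config`;
    # a string llm such as "gpt-4o-mini" is converted to an LLM instance internally.
    agent = Agent(
        role="Researcher",
        goal="Find relevant material",
        backstory="Experienced analyst",
        llm="gpt-4o-mini",
        tools=[lookup_docs],
        verbose=True,
    )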
@@ -138,21 +232,15 @@ class Agent(BaseAgent):
    @model_validator(mode="after")
    def post_init_setup(self):
        self._set_knowledge()
        self.agent_ops_agent_name = self.role
        self.agent_ops_agent_name = self.role or "agent"
        unaccepted_attributes = [
            "AWS_ACCESS_KEY_ID",
            "AWS_SECRET_ACCESS_KEY",
            "AWS_REGION_NAME",
        ]

        # Handle different cases for self.llm
        if isinstance(self.llm, str):
            # If it's a string, create an LLM instance
            self.llm = LLM(model=self.llm)
        elif isinstance(self.llm, LLM):
            # If it's already an LLM instance, keep it as is
            pass
        elif self.llm is None:
            # Handle LLM initialization if not already done
            if self.llm is None:
                # Determine the model name from environment variables or use default
                model_name = (
                    os.environ.get("OPENAI_MODEL_NAME")
@@ -190,9 +278,71 @@ class Agent(BaseAgent):
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
llm_params[key] = value
|
||||
try:
|
||||
# Create a new dictionary for properly typed parameters
|
||||
typed_params = {}
|
||||
|
||||
# Convert and validate values based on parameter type
|
||||
if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
|
||||
if value is not None:
|
||||
try:
|
||||
typed_params[key] = float(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']:
|
||||
if value is not None:
|
||||
try:
|
||||
typed_params[key] = int(value)
|
||||
except (ValueError, TypeError):
|
||||
pass
|
||||
elif key == 'logit_bias' and isinstance(value, str):
|
||||
try:
|
||||
bias_dict = {}
|
||||
for pair in value.split(','):
|
||||
token_id, bias = pair.split(':')
|
||||
bias_dict[int(token_id.strip())] = float(bias.strip())
|
||||
typed_params[key] = bias_dict
|
||||
except (ValueError, AttributeError):
|
||||
pass
|
||||
elif key == 'response_format' and isinstance(value, str):
|
||||
try:
|
||||
import json
|
||||
typed_params[key] = json.loads(value)
|
||||
except json.JSONDecodeError:
|
||||
pass
|
||||
elif key == 'logprobs':
|
||||
if value is not None:
|
||||
typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value)
|
||||
elif key == 'callbacks':
|
||||
typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value
|
||||
elif key == 'stop':
|
||||
typed_params[key] = [value] if isinstance(value, str) else value
|
||||
elif key in ['model', 'base_url', 'api_version', 'api_key']:
|
||||
typed_params[key] = value
|
||||
|
||||
# Update llm_params with properly typed values
|
||||
if typed_params:
|
||||
llm_params.update(typed_params)
|
||||
except (ValueError, AttributeError, json.JSONDecodeError):
|
||||
continue
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
# Create LLM instance with properly typed parameters
|
||||
valid_params = {
|
||||
'model', 'timeout', 'temperature', 'top_p', 'n', 'stop',
|
||||
'max_completion_tokens', 'max_tokens', 'presence_penalty',
|
||||
'frequency_penalty', 'logit_bias', 'response_format',
|
||||
'seed', 'logprobs', 'top_logprobs', 'base_url',
|
||||
'api_version', 'api_key', 'callbacks'
|
||||
}
|
||||
|
||||
# Filter out None values and invalid parameters
|
||||
filtered_params = {}
|
||||
for k, v in llm_params.items():
|
||||
if k in valid_params and v is not None:
|
||||
filtered_params[k] = v
|
||||
|
||||
# Create LLM instance with properly typed parameters
|
||||
self.llm = LLM(**filtered_params)
|
||||
else:
|
||||
# For any other type, attempt to extract relevant attributes
|
||||
llm_params = {
|
||||
@@ -239,7 +389,7 @@ class Agent(BaseAgent):
    def _set_knowledge(self):
        try:
            if self.knowledge_sources:
                knowledge_agent_name = f"{self.role.replace(' ', '_')}"
                knowledge_agent_name = f"{(self.role or 'agent').replace(' ', '_')}"
                if isinstance(self.knowledge_sources, list) and all(
                    isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
                ):
@@ -384,6 +534,32 @@ class Agent(BaseAgent):
                self.response_template.split("{{ .Response }}")[1].strip()
            )

        # Ensure LLM is initialized with proper error handling
        try:
            if not self.llm:
                self.llm = LLM(model="gpt-4")
                if hasattr(self, '_logger'):
                    self._logger.debug("Initialized default LLM with gpt-4 model")
        except Exception as e:
            if hasattr(self, '_logger'):
                self._logger.error(f"Failed to initialize LLM: {str(e)}")
            raise

        # Create token callback with proper error handling
        try:
            token_callback = None
            if hasattr(self, '_token_process'):
                token_callback = TokenCalcHandler(self._token_process)
        except Exception as e:
            if hasattr(self, '_logger'):
                self._logger.warning(f"Failed to create token callback: {str(e)}")
            token_callback = None

        # Initialize callbacks list
        executor_callbacks = []
        if token_callback:
            executor_callbacks.append(token_callback)

        self.agent_executor = CrewAgentExecutor(
            llm=self.llm,
            task=task,
@@ -401,9 +577,9 @@ class Agent(BaseAgent):
            function_calling_llm=self.function_calling_llm,
            respect_context_window=self.respect_context_window,
            request_within_rpm_limit=(
                self._rpm_controller.check_or_wait if self._rpm_controller else None
                self._rpm_controller.check_or_wait if (hasattr(self, '_rpm_controller') and self._rpm_controller is not None) else None
            ),
            callbacks=[TokenCalcHandler(self._token_process)],
            callbacks=executor_callbacks,
        )

    def get_delegation_tools(self, agents: List[BaseAgent]):
@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
    formatting_errors: int = Field(
        default=0, description="Number of formatting errors."
    )
    role: str = Field(description="Role of the agent")
    goal: str = Field(description="Objective of the agent")
    backstory: str = Field(description="Backstory of the agent")
    role: Optional[str] = Field(default=None, description="Role of the agent")
    goal: Optional[str] = Field(default=None, description="Objective of the agent")
    backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
    config: Optional[Dict[str, Any]] = Field(
        description="Configuration for the agent", default=None, exclude=True
    )
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
|
||||
max_tokens: Optional[int] = Field(
|
||||
default=None, description="Maximum number of tokens for the agent's execution."
|
||||
)
|
||||
function_calling_llm: Optional[Any] = Field(
|
||||
default=None, description="Language model for function calling."
|
||||
)
|
||||
step_callback: Optional[Any] = Field(
|
||||
default=None, description="Callback for execution steps."
|
||||
)
|
||||
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
|
||||
default=None, description="Knowledge sources for the agent."
|
||||
)
|
||||
|
||||
model_config = {
|
||||
"arbitrary_types_allowed": True,
|
||||
"extra": "allow", # Allow extra fields in constructor
|
||||
}
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def process_model_config(cls, values):
|
||||
"""Process configuration values before model initialization."""
|
||||
return process_config(values, cls)
|
||||
|
||||
@field_validator("tools")
|
||||
@classmethod
|
||||
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
|
||||
def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
|
||||
"""Validate and process the tools provided to the agent.
|
||||
|
||||
This method ensures that each tool is either an instance of BaseTool
|
||||
or an object with 'name', 'func', and 'description' attributes. If the
|
||||
tool meets these criteria, it is processed and added to the list of
|
||||
tools. Otherwise, a ValueError is raised.
|
||||
This method ensures that each tool is either an instance of BaseTool,
|
||||
a function decorated with @tool, or an object with 'name', 'func',
|
||||
and 'description' attributes. If the tool meets these criteria, it is
|
||||
processed and added to the list of tools. Otherwise, a ValueError is raised.
|
||||
"""
|
||||
if not tools:
|
||||
return []
|
||||
|
||||
processed_tools = []
|
||||
for tool in tools:
|
||||
if isinstance(tool, BaseTool):
|
||||
processed_tools.append(tool)
|
||||
elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
|
||||
# Handle @tool decorated functions
|
||||
processed_tools.append(Tool.from_function(tool))
|
||||
elif (
|
||||
hasattr(tool, "name")
|
||||
and hasattr(tool, "func")
|
||||
@@ -157,31 +179,54 @@ class BaseAgent(ABC, BaseModel):
|
||||
):
|
||||
# Tool has the required attributes, create a Tool instance
|
||||
processed_tools.append(Tool.from_langchain(tool))
|
||||
else:
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Invalid tool type: {type(tool)}. "
|
||||
"Tool must be an instance of BaseTool or "
|
||||
"an object with 'name', 'func', and 'description' attributes."
|
||||
"Tool must be an instance of BaseTool, a @tool decorated function, "
|
||||
"or an object with 'name', 'func', and 'description' attributes."
|
||||
)
|
||||
return processed_tools
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_and_set_attributes(self):
|
||||
# Validate required fields
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if getattr(self, field) is None:
|
||||
raise ValueError(
|
||||
f"{field} must be provided either directly or through config"
|
||||
)
|
||||
"""Validate and set attributes for the agent.
|
||||
|
||||
This method ensures that attributes are properly set and initialized,
|
||||
either from direct parameters or configuration.
|
||||
"""
|
||||
# Store original values for interpolation
|
||||
self._original_role = self.role
|
||||
self._original_goal = self.goal
|
||||
self._original_backstory = self.backstory
|
||||
|
||||
# Set private attributes
|
||||
# Process config if provided
|
||||
if self.config:
|
||||
config_data = self.config
|
||||
if isinstance(config_data, str):
|
||||
import json
|
||||
try:
|
||||
config_data = json.loads(config_data)
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError("Invalid JSON in config")
|
||||
|
||||
# Update fields from config if they're None
|
||||
for field in ["role", "goal", "backstory"]:
|
||||
if field in config_data and getattr(self, field) is None:
|
||||
setattr(self, field, config_data[field])
|
||||
|
||||
# Set default values for required fields if they're still None
|
||||
self.role = self.role or "Assistant"
|
||||
self.goal = self.goal or "Help the user accomplish their tasks"
|
||||
self.backstory = self.backstory or "I am an AI assistant ready to help"
|
||||
|
||||
# Initialize tools handler if not set
|
||||
if not hasattr(self, 'tools_handler') or self.tools_handler is None:
|
||||
self.tools_handler = ToolsHandler()
|
||||
|
||||
# Initialize logger and rpm controller
|
||||
self._logger = Logger(verbose=self.verbose)
|
||||
if self.max_rpm and not self._rpm_controller:
|
||||
self._rpm_controller = RPMController(
|
||||
max_rpm=self.max_rpm, logger=self._logger
|
||||
)
|
||||
if not self._token_process:
|
||||
self._token_process = TokenProcess()
|
||||
if self.max_rpm:
|
||||
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
|
||||
|
||||
return self
|
||||
|
||||
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
|
||||
@property
|
||||
def key(self):
|
||||
source = [
|
||||
self._original_role or self.role,
|
||||
self._original_goal or self.goal,
|
||||
self._original_backstory or self.backstory,
|
||||
str(self._original_role or self.role or ""),
|
||||
str(self._original_goal or self.goal or ""),
|
||||
str(self._original_backstory or self.backstory or ""),
|
||||
]
|
||||
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
|
||||
|
||||
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
|
||||
"tools_handler",
|
||||
"cache_handler",
|
||||
"llm",
|
||||
"function_calling_llm",
|
||||
}
|
||||
|
||||
# Copy llm and clear callbacks
|
||||
existing_llm = shallow_copy(self.llm)
|
||||
# Copy LLMs and clear callbacks
|
||||
existing_llm = shallow_copy(self.llm) if self.llm else None
|
||||
existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
|
||||
|
||||
# Create base data
|
||||
copied_data = self.model_dump(exclude=exclude)
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
|
||||
|
||||
# Create new instance with copied data
|
||||
copied_agent = type(self)(
|
||||
**copied_data,
|
||||
llm=existing_llm,
|
||||
function_calling_llm=existing_function_calling_llm,
|
||||
tools=self.tools
|
||||
)
|
||||
|
||||
# Copy private attributes
|
||||
copied_agent._original_role = self._original_role
|
||||
copied_agent._original_goal = self._original_goal
|
||||
copied_agent._original_backstory = self._original_backstory
|
||||
|
||||
return copied_agent
|
||||
|
||||
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
|
||||
"""Interpolate inputs into the agent description and backstory."""
|
||||
if self._original_role is None:
|
||||
self._original_role = self.role
|
||||
self._original_role = self.role or ""
|
||||
if self._original_goal is None:
|
||||
self._original_goal = self.goal
|
||||
self._original_goal = self.goal or ""
|
||||
if self._original_backstory is None:
|
||||
self._original_backstory = self.backstory
|
||||
self._original_backstory = self.backstory or ""
|
||||
|
||||
if inputs:
|
||||
self.role = self._original_role.format(**inputs)
|
||||
self.goal = self._original_goal.format(**inputs)
|
||||
self.backstory = self._original_backstory.format(**inputs)
|
||||
self.role = self._original_role.format(**inputs) if self._original_role else None
|
||||
self.goal = self._original_goal.format(**inputs) if self._original_goal else None
|
||||
self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
|
||||
|
||||
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
|
||||
"""Set the cache handler for the agent.
|
||||
|
||||
@@ -82,16 +82,17 @@ class CrewAgentExecutorMixin:
                )
                self.crew._long_term_memory.save(long_term_memory)

                for entity in evaluation.entities:
                    entity_memory = EntityMemoryItem(
                        name=entity.name,
                        type=entity.type,
                        description=entity.description,
                        relationships="\n".join(
                            [f"- {r}" for r in entity.relationships]
                        ),
                    )
                    self.crew._entity_memory.save(entity_memory)
                if hasattr(evaluation, 'entities') and evaluation.entities:
                    for entity in evaluation.entities:
                        entity_memory = EntityMemoryItem(
                            name=entity.name,
                            type=entity.type,
                            description=entity.description,
                            relationships="\n".join(
                                [f"- {r}" for r in entity.relationships]
                            ),
                        )
                        self.crew._entity_memory.save(entity_memory)
            except AttributeError as e:
                print(f"Missing attributes for long term memory: {e}")
                pass
@@ -68,7 +68,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        self.tools_handler = tools_handler
        self.original_tools = original_tools
        self.step_callback = step_callback
        self.use_stop_words = self.llm.supports_stop_words()
        self.use_stop_words = self.llm.supports_stop_words() if self.llm else False
        self.tools_description = tools_description
        self.function_calling_llm = function_calling_llm
        self.respect_context_window = respect_context_window
@@ -147,7 +147,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
                # Directly append the result to the messages if the
                # tool is "Add image to content" in case of multimodal
                # agents
                if formatted_answer.tool == self._i18n.tools("add_image")["name"]:
                add_image_tool_name = self._i18n.tools("add_image")
                if add_image_tool_name and formatted_answer.tool == add_image_tool_name:
                    self.messages.append(tool_result.result)
                    continue

@@ -214,13 +215,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        if self.agent.verbose or (
            hasattr(self, "crew") and getattr(self.crew, "verbose", False)
        ):
            agent_role = self.agent.role.split("\n")[0]
            agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
            self._printer.print(
                content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
            )
            self._printer.print(
                content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
            )
            if self.task and self.task.description:
                self._printer.print(
                    content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
                )

    def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
        if self.agent is None:
@@ -228,7 +230,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
        if self.agent.verbose or (
            hasattr(self, "crew") and getattr(self.crew, "verbose", False)
        ):
            agent_role = self.agent.role.split("\n")[0]
            agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
            if isinstance(formatted_answer, AgentAction):
                thought = re.sub(r"\n+", "\n", formatted_answer.thought)
                formatted_json = json.dumps(
@@ -5,6 +5,7 @@ from pathlib import Path

import click
import requests
from typing import Any

from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS

@@ -192,7 +193,7 @@ def download_data(response):
    data_chunks = []
    with click.progressbar(
        length=total_size, label="Downloading", show_pos=True
    ) as progress_bar:
    ) as progress_bar:  # type: Any
        for chunk in response.iter_content(block_size):
            if chunk:
                data_chunks.append(chunk)
@@ -6,6 +6,8 @@ from concurrent.futures import Future
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from crewai.tools.base_tool import BaseTool

from pydantic import (
    UUID4,
    BaseModel,
@@ -728,7 +730,7 @@ class Crew(BaseModel):
            tools_for_task = task.tools or agent_to_use.tools or []
            tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)

            self._log_task_start(task, agent_to_use.role)
            self._log_task_start(task, agent_to_use.role if agent_to_use and agent_to_use.role else "")

            if isinstance(task, ConditionalTask):
                skipped_task_output = self._handle_conditional_task(
@@ -794,8 +796,8 @@ class Crew(BaseModel):
        return None

    def _prepare_tools(
        self, agent: BaseAgent, task: Task, tools: List[Tool]
    ) -> List[Tool]:
        self, agent: BaseAgent, task: Task, tools: List[Union[Tool, BaseTool]]
    ) -> List[Union[Tool, BaseTool]]:
        # Add delegation tools if agent allows delegation
        if agent.allow_delegation:
            if self.process == Process.hierarchical:
@@ -824,8 +826,8 @@ class Crew(BaseModel):
        return task.agent

    def _merge_tools(
        self, existing_tools: List[Tool], new_tools: List[Tool]
    ) -> List[Tool]:
        self, existing_tools: List[Union[Tool, BaseTool]], new_tools: List[Union[Tool, BaseTool]]
    ) -> List[Union[Tool, BaseTool]]:
        """Merge new tools into existing tools list, avoiding duplicates by tool name."""
        if not new_tools:
            return existing_tools
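To picture the name-based de-duplication the _merge_tools docstring describes, here is a hedged sketch of one plausible reading (not the code from this diff), in which incoming tools win when names collide:

    def merge_by_name(existing_tools, new_tools):
        # drop existing tools whose name is re-supplied, then append the new ones
        new_names = {tool.name for tool in new_tools}
        return [t for t in existing_tools if t.name not in new_names] + list(new_tools)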
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field

@@ -23,14 +23,25 @@ class CrewOutput(BaseModel):
    )
    token_usage: UsageMetrics = Field(description="Processed token summary", default={})

    @property
    def json(self) -> Optional[str]:
    def json(
        self,
        *,
        include: Union[set[str], None] = None,
        exclude: Union[set[str], None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        encoder: Optional[Callable[[Any], Any]] = None,
        models_as_dict: bool = True,
        **dumps_kwargs: Any,
    ) -> str:
        if self.tasks_output[-1].output_format != OutputFormat.JSON:
            raise ValueError(
                "No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
            )

        return json.dumps(self.json_dict)
        return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Convert json_output and pydantic_output to a dictionary."""
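Because json moves here from a read-only property to a Pydantic-style method, call sites change from attribute access to a call. A hedged before/after sketch (the crew_output variable is assumed):

    # before this change: json was a property
    # text = crew_output.json
    # after: it is a method whose extra keyword arguments flow into json.dumps
    text = crew_output.json(indent=2)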
@@ -106,7 +106,12 @@ class FlowPlot:

        # Add nodes to the network
        try:
            add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
            add_nodes_to_network(
                net,
                flow=self.flow,
                pos=node_positions,
                node_styles=self.node_styles
            )
        except Exception as e:
            raise RuntimeError(f"Failed to add nodes to network: {str(e)}")

@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field

with warnings.catch_warnings():
    warnings.simplefilter("ignore", UserWarning)
    import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
        sys.stderr = old_stderr


class LLM:
class LLM(BaseModel):
    model: str = "gpt-4"  # Set default model
    timeout: Optional[Union[float, int]] = None
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stop: Optional[Union[str, List[str]]] = None
    max_completion_tokens: Optional[int] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[int, float]] = None
    response_format: Optional[Dict[str, Any]] = None
    seed: Optional[int] = None
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    base_url: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    callbacks: Optional[List[Any]] = None
    context_window_size: Optional[int] = None
    kwargs: Dict[str, Any] = Field(default_factory=dict)
    logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))

    def __init__(
        self,
        model: str,
        model: Optional[Union[str, 'LLM']] = "gpt-4",
        timeout: Optional[Union[float, int]] = None,
        temperature: Optional[float] = None,
        top_p: Optional[float] = None,
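With LLM now declared as a BaseModel whose fields all carry defaults, a minimal construction looks roughly like the sketch below; the module path, model name, and sampling values are assumptions for illustration, not taken from this diff:

    from crewai.llm import LLM

    llm = LLM(model="gpt-4o", temperature=0.2, max_tokens=512)

    # the capability helpers further down consult LiteLLM's supported-parameter list
    if llm.supports_stop_words():
        llm.stop = ["Observation:"]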
@@ -114,118 +139,427 @@ class LLM:
|
||||
base_url: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.stop = stop
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
context_window_size: Optional[int] = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
# Initialize with default values
|
||||
init_dict = {
|
||||
"model": model if isinstance(model, str) else "gpt-4",
|
||||
"timeout": timeout,
|
||||
"temperature": temperature,
|
||||
"top_p": top_p,
|
||||
"n": n,
|
||||
"stop": stop,
|
||||
"max_completion_tokens": max_completion_tokens,
|
||||
"max_tokens": max_tokens,
|
||||
"presence_penalty": presence_penalty,
|
||||
"frequency_penalty": frequency_penalty,
|
||||
"logit_bias": logit_bias,
|
||||
"response_format": response_format,
|
||||
"seed": seed,
|
||||
"logprobs": logprobs,
|
||||
"top_logprobs": top_logprobs,
|
||||
"base_url": base_url,
|
||||
"api_version": api_version,
|
||||
"api_key": api_key,
|
||||
"callbacks": callbacks,
|
||||
"context_window_size": context_window_size,
|
||||
"kwargs": kwargs,
|
||||
}
|
||||
super().__init__(**init_dict)
|
||||
|
||||
# Initialize model with default value
|
||||
self.model = "gpt-4" # Default fallback
|
||||
|
||||
# Extract and validate model name
|
||||
if isinstance(model, LLM):
|
||||
# Extract and validate model name from LLM instance
|
||||
if hasattr(model, 'model'):
|
||||
if isinstance(model.model, str):
|
||||
self.model = model.model
|
||||
else:
|
||||
# Try to extract string model name from nested LLM
|
||||
if isinstance(model.model, LLM):
|
||||
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
|
||||
else:
|
||||
# Extract and validate model name for non-LLM instances
|
||||
if not isinstance(model, str):
|
||||
if self.logger:
|
||||
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
|
||||
if model is not None:
|
||||
if hasattr(model, 'model_name'):
|
||||
model_name = getattr(model, 'model_name', None)
|
||||
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||
elif hasattr(model, 'model'):
|
||||
model_attr = getattr(model, 'model', None)
|
||||
self.model = str(model_attr) if model_attr is not None else "gpt-4"
|
||||
elif hasattr(model, '_model_name'):
|
||||
model_name = getattr(model, '_model_name', None)
|
||||
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||
else:
|
||||
self.model = "gpt-4" # Default fallback
|
||||
if self.logger:
|
||||
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
|
||||
else:
|
||||
self.model = "gpt-4" # Default fallback for None
|
||||
if self.logger:
|
||||
self.logger.warning("Model is None, using default: gpt-4")
|
||||
else:
|
||||
self.model = str(model) # Ensure it's a string
|
||||
|
||||
# If model is an LLM instance, copy its configuration
|
||||
if isinstance(model, LLM):
|
||||
# Extract and validate model name first
|
||||
if hasattr(model, 'model'):
|
||||
if isinstance(model.model, str):
|
||||
self.model = model.model
|
||||
else:
|
||||
# Try to extract string model name from nested LLM
|
||||
if isinstance(model.model, LLM):
|
||||
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
|
||||
|
||||
# Copy other configuration
|
||||
self.timeout = model.timeout
|
||||
self.temperature = model.temperature
|
||||
self.top_p = model.top_p
|
||||
self.n = model.n
|
||||
self.stop = model.stop
|
||||
self.max_completion_tokens = model.max_completion_tokens
|
||||
self.max_tokens = model.max_tokens
|
||||
self.presence_penalty = model.presence_penalty
|
||||
self.frequency_penalty = model.frequency_penalty
|
||||
self.logit_bias = model.logit_bias
|
||||
self.response_format = model.response_format
|
||||
self.seed = model.seed
|
||||
self.logprobs = model.logprobs
|
||||
self.top_logprobs = model.top_logprobs
|
||||
self.base_url = model.base_url
|
||||
self.api_version = model.api_version
|
||||
self.api_key = model.api_key
|
||||
self.callbacks = model.callbacks
|
||||
self.context_window_size = model.context_window_size
|
||||
self.kwargs = model.kwargs
|
||||
|
||||
# Final validation of model name
|
||||
if not isinstance(self.model, str):
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
|
||||
else:
|
||||
# Extract and validate model name for non-LLM instances
|
||||
if not isinstance(model, str):
|
||||
if self.logger:
|
||||
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
|
||||
if model is not None:
|
||||
if hasattr(model, 'model_name'):
|
||||
model_name = getattr(model, 'model_name', None)
|
||||
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||
elif hasattr(model, 'model'):
|
||||
model_attr = getattr(model, 'model', None)
|
||||
self.model = str(model_attr) if model_attr is not None else "gpt-4"
|
||||
elif hasattr(model, '_model_name'):
|
||||
model_name = getattr(model, '_model_name', None)
|
||||
self.model = str(model_name) if model_name is not None else "gpt-4"
|
||||
else:
|
||||
self.model = "gpt-4" # Default fallback
|
||||
if self.logger:
|
||||
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
|
||||
else:
|
||||
self.model = "gpt-4" # Default fallback for None
|
||||
if self.logger:
|
||||
self.logger.warning("Model is None, using default: gpt-4")
|
||||
else:
|
||||
self.model = str(model) # Ensure it's a string
|
||||
|
||||
# Final validation
|
||||
if not isinstance(self.model, str):
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
|
||||
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.stop = stop
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
|
||||
# Ensure model is a string after initialization
|
||||
if not isinstance(self.model, str):
|
||||
self.model = "gpt-4"
|
||||
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
def call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
callbacks: Optional[List[Any]] = None
|
||||
) -> str:
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
# Store original model to restore later
|
||||
original_model = self.model
|
||||
|
||||
try:
|
||||
# Ensure model is a string before making the call
|
||||
if not isinstance(self.model, str):
|
||||
if self.logger:
|
||||
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
|
||||
if isinstance(self.model, LLM):
|
||||
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
|
||||
elif hasattr(self.model, 'model_name'):
|
||||
self.model = str(self.model.model_name)
|
||||
elif hasattr(self.model, 'model'):
|
||||
if isinstance(self.model.model, str):
|
||||
self.model = str(self.model.model)
|
||||
elif hasattr(self.model.model, 'model_name'):
|
||||
self.model = str(self.model.model.model_name)
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
|
||||
else:
|
||||
self.model = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Could not extract model name, using default: gpt-4")
|
||||
|
||||
if self.logger:
|
||||
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
|
||||
|
||||
# Create base params with validated model name
|
||||
# Extract model name string
|
||||
model_name = None
|
||||
if isinstance(self.model, str):
|
||||
model_name = self.model
|
||||
elif hasattr(self.model, 'model_name'):
|
||||
model_name = str(self.model.model_name)
|
||||
elif hasattr(self.model, 'model'):
|
||||
if isinstance(self.model.model, str):
|
||||
model_name = str(self.model.model)
|
||||
elif hasattr(self.model.model, 'model_name'):
|
||||
model_name = str(self.model.model.model_name)
|
||||
|
||||
if not model_name:
|
||||
model_name = "gpt-4"
|
||||
if self.logger:
|
||||
self.logger.warning("Could not extract model name, using default: gpt-4")
|
||||
|
||||
params = {
|
||||
"model": self.model,
|
||||
"model": model_name,
|
||||
"messages": messages,
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
"stream": False,
|
||||
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
|
||||
"api_base": self.base_url,
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
if self.logger:
|
||||
self.logger.debug(f"Using model parameters: {params}")
|
||||
|
||||
# Add API configuration if available
|
||||
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
|
||||
if api_key:
|
||||
params["api_key"] = api_key
|
||||
|
||||
# Try to get supported parameters for the model
|
||||
try:
|
||||
supported_params = get_supported_openai_params(self.model)
|
||||
optional_params = {}
|
||||
|
||||
if supported_params:
|
||||
param_mapping = {
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs
|
||||
}
|
||||
|
||||
# Only add parameters that are supported and not None
|
||||
optional_params = {
|
||||
k: v for k, v in param_mapping.items()
|
||||
if k in supported_params and v is not None
|
||||
}
|
||||
if "logprobs" in supported_params and self.logprobs is not None:
|
||||
optional_params["logprobs"] = self.logprobs
|
||||
if "top_logprobs" in supported_params and self.top_logprobs is not None:
|
||||
optional_params["top_logprobs"] = self.top_logprobs
|
||||
except Exception as e:
|
||||
if self.logger:
|
||||
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
|
||||
# If we can't get supported params, just add non-None parameters
|
||||
param_mapping = {
|
||||
"timeout": self.timeout,
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs
|
||||
}
|
||||
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
|
||||
|
||||
# Update params with optional parameters
|
||||
params.update(optional_params)
|
||||
|
||||
# Add API endpoint configuration if available
|
||||
if self.base_url:
|
||||
params["api_base"] = self.base_url
|
||||
if self.api_version:
|
||||
params["api_version"] = self.api_version
|
||||
|
||||
# Final validation of model parameter
|
||||
if not isinstance(params["model"], str):
|
||||
if self.logger:
|
||||
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
|
||||
params["model"] = "gpt-4"
|
||||
|
||||
# Update params with non-None optional parameters
|
||||
params.update({k: v for k, v in optional_params.items() if v is not None})
|
||||
|
||||
# Add any additional kwargs
|
||||
if self.kwargs:
|
||||
params.update(self.kwargs)
|
||||
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
content = response["choices"][0]["message"]["content"]
|
||||
|
||||
# Extract usage metrics
|
||||
usage = response.get("usage", {})
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "update_token_usage"):
|
||||
callback.update_token_usage(usage)
|
||||
|
||||
return content
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
str(e)
|
||||
)._is_context_limit_error(str(e)):
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
|
||||
raise # Re-raise the exception after logging
|
||||
finally:
|
||||
# Always restore the original model object
|
||||
self.model = original_model
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
bool: True if the model supports function calling, False otherwise
|
||||
"""
|
||||
try:
|
||||
params = get_supported_openai_params(model=self.model)
|
||||
return "response_format" in params
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
if self.logger:
|
||||
self.logger.error(f"Failed to get supported params: {str(e)}")
|
||||
return False
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
Returns False if the LLM is not properly initialized."""
|
||||
if not hasattr(self, 'model') or self.model is None:
|
||||
return False
|
||||
try:
|
||||
params = get_supported_openai_params(model=self.model)
|
||||
return "stop" in params
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to get supported params: {str(e)}")
|
||||
if self.logger:
|
||||
self.logger.error(f"Failed to get supported params: {str(e)}")
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model.
|
||||
|
||||
Returns:
|
||||
int: The context window size in tokens
|
||||
"""
|
||||
# Only using 75% of the context window size to avoid cutting the message in the middle
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
if self.context_window_size is not None and self.context_window_size != 0:
|
||||
return int(self.context_window_size)
|
||||
|
||||
self.context_window_size = int(
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
|
||||
)
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if self.model.startswith(key):
|
||||
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
window_size = DEFAULT_CONTEXT_WINDOW_SIZE
|
||||
if isinstance(self.model, str):
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if self.model.startswith(key):
|
||||
window_size = value
|
||||
break
|
||||
|
||||
self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
return self.context_window_size
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
callback_types = [type(callback) for callback in callbacks]
|
||||
for callback in litellm.success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm.success_callback.remove(callback)
|
||||
def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
|
||||
"""Set callbacks for the LLM.
|
||||
|
||||
Args:
|
||||
callbacks: Optional list of callback functions. If None, no callbacks will be set.
|
||||
"""
|
||||
if callbacks is not None:
|
||||
callback_types = [type(callback) for callback in callbacks]
|
||||
for callback in litellm.success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm.success_callback.remove(callback)
|
||||
|
||||
for callback in litellm._async_success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm._async_success_callback.remove(callback)
|
||||
for callback in litellm._async_success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
litellm._async_success_callback.remove(callback)
|
||||
|
||||
litellm.callbacks = callbacks
|
||||
litellm.callbacks = callbacks
|
||||
|
||||
def set_env_callbacks(self):
|
||||
"""
|
||||
|
||||
@@ -269,7 +269,9 @@ class Task(BaseModel):
    @model_validator(mode="after")
    def check_tools(self):
        """Check if the tools are set."""
        if not self.tools and self.agent and self.agent.tools:
        if self.agent and self.agent.tools:
            if self.tools is None:
                self.tools = []
            self.tools.extend(self.agent.tools)
        return self

@@ -348,7 +350,8 @@ class Task(BaseModel):
        self.prompt_context = context
        tools = tools or self.tools or []

        self.processed_by_agents.add(agent.role)
        if agent and agent.role:
            self.processed_by_agents.add(agent.role)

        result = agent.execute_task(
            task=self,
@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union

from pydantic import BaseModel, Field, model_validator

@@ -34,8 +34,19 @@ class TaskOutput(BaseModel):
        self.summary = f"{excerpt}..."
        return self

    @property
    def json(self) -> Optional[str]:
    def json(
        self,
        *,
        include: Union[set[str], None] = None,
        exclude: Union[set[str], None] = None,
        by_alias: bool = False,
        exclude_unset: bool = False,
        exclude_defaults: bool = False,
        exclude_none: bool = False,
        encoder: Optional[Callable[[Any], Any]] = None,
        models_as_dict: bool = True,
        **dumps_kwargs: Any,
    ) -> str:
        if self.output_format != OutputFormat.JSON:
            raise ValueError(
                """
@@ -45,7 +56,7 @@ class TaskOutput(BaseModel):
            """
            )

        return json.dumps(self.json_dict)
        return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)

    def to_dict(self) -> Dict[str, Any]:
        """Convert json_output and pydantic_output to a dictionary."""
@@ -19,13 +19,13 @@ class BaseAgentTool(BaseTool):
        default_factory=I18N, description="Internationalization settings"
    )

    def sanitize_agent_name(self, name: str) -> str:
    def sanitize_agent_name(self, name: Optional[str]) -> str:
        """
        Sanitize agent role name by normalizing whitespace and setting to lowercase.
        Converts all whitespace (including newlines) to single spaces and removes quotes.

        Args:
            name (str): The agent role name to sanitize
            name (Optional[str]): The agent role name to sanitize

        Returns:
            str: The sanitized agent role name, with whitespace normalized,
@@ -142,7 +142,12 @@ class CrewStructuredTool:
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields)
|
||||
return create_model(
|
||||
schema_name,
|
||||
__base__=BaseModel,
|
||||
__config__=None,
|
||||
**{k: v for k, v in fields.items()}
|
||||
)
|
||||
|
||||
def _validate_function_signature(self) -> None:
|
||||
"""Validate that the function signature matches the args schema."""
|
||||
@@ -170,7 +175,7 @@ class CrewStructuredTool:
|
||||
f"not found in args_schema"
|
||||
)
|
||||
|
||||
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
|
||||
def _parse_args(self, raw_args: Union[str, dict[str, Any]]) -> dict[str, Any]:
|
||||
"""Parse and validate the input arguments against the schema.
|
||||
|
||||
Args:
|
||||
@@ -178,6 +183,9 @@ class CrewStructuredTool:
|
||||
|
||||
Returns:
|
||||
The validated arguments as a dictionary
|
||||
|
||||
Raises:
|
||||
ValueError: If the arguments cannot be parsed or fail validation
|
||||
"""
|
||||
if isinstance(raw_args, str):
|
||||
try:
|
||||
@@ -195,8 +203,8 @@ class CrewStructuredTool:
|
||||
|
||||
async def ainvoke(
|
||||
self,
|
||||
input: Union[str, dict],
|
||||
config: Optional[dict] = None,
|
||||
input: Union[str, dict[str, Any]],
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Asynchronously invoke the tool.
|
||||
@@ -229,7 +237,10 @@ class CrewStructuredTool:
|
||||
return self.invoke(input_dict)
|
||||
|
||||
def invoke(
|
||||
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
|
||||
self,
|
||||
input: Union[str, dict[str, Any]],
|
||||
config: Optional[dict[str, Any]] = None,
|
||||
**kwargs: Any
|
||||
) -> Any:
|
||||
"""Main method for tool execution."""
|
||||
parsed_args = self._parse_args(input)
|
||||
|
||||
@@ -10,8 +10,24 @@ class Logger(BaseModel):
    _printer: Printer = PrivateAttr(default_factory=Printer)

    def log(self, level, message, color="bold_yellow"):
        if self.verbose:
        if self.verbose or level.upper() in ["WARNING", "ERROR"]:
            timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
            self._printer.print(
                f"\n[{timestamp}][{level.upper()}]: {message}", color=color
            )

    def debug(self, message: str) -> None:
        """Log a debug message if verbose is enabled."""
        self.log("debug", message, color="bold_blue")

    def info(self, message: str) -> None:
        """Log an info message if verbose is enabled."""
        self.log("info", message, color="bold_green")

    def warning(self, message: str) -> None:
        """Log a warning message."""
        self.log("warning", message, color="bold_yellow")

    def error(self, message: str) -> None:
        """Log an error message."""
        self.log("error", message, color="bold_red")
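A short usage sketch of the new level helpers: warnings and errors are now printed even when verbose is False, while debug and info stay gated on the verbose flag.

    from crewai.utilities.logger import Logger

    logger = Logger(verbose=False)
    logger.debug("hidden unless verbose=True")
    logger.info("hidden unless verbose=True")
    logger.warning("always printed")  # WARNING bypasses the verbose check
    logger.error("always printed")    # so does ERROR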
@@ -63,16 +63,32 @@ class Prompts(BaseModel):
|
||||
for component in components
|
||||
if component != "task"
|
||||
]
|
||||
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
|
||||
prompt = prompt_template.replace(
|
||||
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
|
||||
)
|
||||
response = response_template.split("{{ .Response }}")[0]
|
||||
prompt = f"{system}\n{prompt}\n{response}"
|
||||
system = ""
|
||||
if system_template:
|
||||
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
|
||||
|
||||
prompt_text = ""
|
||||
if prompt_template:
|
||||
prompt_text = prompt_template.replace(
|
||||
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
|
||||
)
|
||||
|
||||
response = ""
|
||||
if response_template:
|
||||
response = response_template.split("{{ .Response }}")[0]
|
||||
|
||||
parts = [p for p in [system, prompt_text, response] if p]
|
||||
prompt = "\n".join(parts) if parts else ""
|
||||
|
||||
# Get agent attributes with default values
|
||||
goal = str(getattr(self.agent, 'goal', '') or '')
|
||||
role = str(getattr(self.agent, 'role', '') or '')
|
||||
backstory = str(getattr(self.agent, 'backstory', '') or '')
|
||||
|
||||
# Replace placeholders with agent attributes
|
||||
prompt = (
|
||||
prompt.replace("{goal}", self.agent.goal)
|
||||
.replace("{role}", self.agent.role)
|
||||
.replace("{backstory}", self.agent.backstory)
|
||||
prompt.replace("{goal}", goal)
|
||||
.replace("{role}", role)
|
||||
.replace("{backstory}", backstory)
|
||||
)
|
||||
return prompt
|
||||
|
||||
src/crewai/utilities/token_process.py (new file, 74 lines)
@@ -0,0 +1,74 @@
"""Token processing utility for tracking and managing token usage."""
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
class TokenProcess:
|
||||
"""Handles token processing and tracking for agents."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the token processor."""
|
||||
self._total_tokens = 0
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
self._successful_requests = 0
|
||||
|
||||
def sum_prompt_tokens(self, count: int) -> None:
|
||||
"""Add to prompt token count.
|
||||
|
||||
Args:
|
||||
count (int): Number of prompt tokens to add
|
||||
"""
|
||||
self._prompt_tokens += count
|
||||
self._total_tokens += count
|
||||
|
||||
def sum_completion_tokens(self, count: int) -> None:
|
||||
"""Add to completion token count.
|
||||
|
||||
Args:
|
||||
count (int): Number of completion tokens to add
|
||||
"""
|
||||
self._completion_tokens += count
|
||||
self._total_tokens += count
|
||||
|
||||
def sum_cached_prompt_tokens(self, count: int) -> None:
|
||||
"""Add to cached prompt token count.
|
||||
|
||||
Args:
|
||||
count (int): Number of cached prompt tokens to add
|
||||
"""
|
||||
self._cached_prompt_tokens += count
|
||||
|
||||
def sum_successful_requests(self, count: int) -> None:
|
||||
"""Add to successful requests count.
|
||||
|
||||
Args:
|
||||
count (int): Number of successful requests to add
|
||||
"""
|
||||
self._successful_requests += count
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset all token counts to zero."""
|
||||
self._total_tokens = 0
|
||||
self._prompt_tokens = 0
|
||||
self._completion_tokens = 0
|
||||
self._cached_prompt_tokens = 0
|
||||
self._successful_requests = 0
|
||||
|
||||
def get_summary(self) -> UsageMetrics:
|
||||
"""Get a summary of token usage.
|
||||
|
||||
Returns:
|
||||
UsageMetrics: Object containing token usage metrics
|
||||
"""
|
||||
return UsageMetrics(
|
||||
total_tokens=self._total_tokens,
|
||||
prompt_tokens=self._prompt_tokens,
|
||||
cached_prompt_tokens=self._cached_prompt_tokens,
|
||||
completion_tokens=self._completion_tokens,
|
||||
successful_requests=self._successful_requests
|
||||
)
|
||||
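Typical use of the new utility, mirroring how the TokenCalcHandler callback is expected to feed it (the counts below are made up):

    from crewai.utilities.token_process import TokenProcess

    process = TokenProcess()
    process.sum_prompt_tokens(120)
    process.sum_completion_tokens(48)
    process.sum_successful_requests(1)

    summary = process.get_summary()  # UsageMetrics(total_tokens=168, ...)
    process.reset()                  # back to all zeros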
@@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools():
    crew.kickoff()

    # Verify task tools override agent tools
    assert len(task.tools) == 1  # AnotherTestTool
    assert any(isinstance(tool, AnotherTestTool) for tool in task.tools)
    assert not any(isinstance(tool, TestTool) for tool in task.tools)
    tools = task.tools or []
    assert len(tools) == 1  # AnotherTestTool
    assert any(isinstance(tool, AnotherTestTool) for tool in tools)
    assert not any(isinstance(tool, TestTool) for tool in tools)

    # Verify agent tools remain unchanged
    assert len(new_researcher.tools) == 1
tests/pytest.ini (new file, 6 lines)
@@ -0,0 +1,6 @@
[pytest]
markers =
    vcr: Mark a test as using VCR.py for recording/replaying HTTP interactions

[vcr]
record_mode = none
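With the marker registered, the cassette-backed tests can be selected or excluded from the command line in the usual pytest way:

    pytest -m vcr        # run only tests marked with @pytest.mark.vcr
    pytest -m "not vcr"  # skip them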