Compare commits

...

12 Commits

Author SHA1 Message Date
Devin AI
bfb578d506 fix: Add proper null checks for logger calls and improve type safety in LLM class
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 21:54:49 +00:00
Devin AI
5d3c34b3ea fix: Improve type annotations across multiple files
- Replace Optional[set[str]] with Union[set[str], None] in json methods
- Fix add_nodes_to_network call parameters in flow_visualizer
- Add __base__=BaseModel to create_model call in structured_tool
- Clean up imports in provider.py

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 21:29:15 +00:00
Devin AI
8ec2eb7d72 fix(agent): improve token tracking and logging functionality
- Add proper debug, info, warning, and error methods to Logger class
- Ensure warnings and errors are always shown regardless of verbose mode
- Fix token process initialization and tracking in Agent class
- Update TokenProcess import to use correct class from agent_builder utilities

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-01-01 20:59:39 +00:00
Devin AI
39bdc7e4d4 fix: Improve type annotations and add proper None checks
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-31 23:24:01 +00:00
Devin AI
344fa9bbe5 Merge remote-tracking branch 'origin/pr-1833' into pr-1833
- Integrate latest changes from remote
- Keep LLM parameter handling improvements
- Maintain test fixes and token process utility

Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-31 23:13:28 +00:00
Devin AI
0fd0b5c74f fix: Improve LLM parameter handling and fix test timeouts
- Add proper model name extraction in LLM class
- Handle optional parameters correctly in litellm calls
- Fix Agent constructor compatibility with BaseAgent
- Add token process utility for better tracking
- Clean up parameter handling in LLM class

Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-31 23:12:13 +00:00
João Moura
090d5128cb Merge branch 'main' into pr-1833 2024-12-31 18:41:16 -03:00
Devin AI
dec255e87a fix: Improve type conversion for LLM parameters and handle None values properly
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-31 20:45:25 +00:00
Devin AI
f75b07ce82 Add pytest.ini to force VCR to use recorded cassettes in CI
Co-Authored-By: Joe Moura <joao@crewai.com>
2024-12-31 20:36:02 +00:00
Brandon Hancock
cf2f21cbfb Fix failling ollama tasks 2024-12-31 12:01:03 -05:00
Brandon Hancock
c4a401b247 change litellm version 2024-12-31 11:28:29 -05:00
Brandon Hancock
b0d545992a Suppressed userWarnings from litellm pydantic issues 2024-12-31 11:19:48 -05:00
18 changed files with 917 additions and 186 deletions

View File

@@ -1,3 +1,5 @@
from __future__ import annotations
import os
import shutil
import subprocess
@@ -21,6 +23,9 @@ from crewai.tools.base_tool import Tool
from crewai.utilities import Converter, Prompts
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import generate_model_description
from crewai.utilities.logger import Logger
from crewai.utilities.rpm_controller import RPMController
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.utilities.token_counter_callback import TokenCalcHandler
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -45,24 +50,113 @@ class Agent(BaseAgent):
Each agent has a role, a goal, a backstory, and an optional language model (llm).
The agent can also have memory, can operate in verbose mode, and can delegate tasks to other agents.
Attributes:
agent_executor: An instance of the CrewAgentExecutor class.
role: The role of the agent.
goal: The objective of the agent.
backstory: The backstory of the agent.
knowledge: The knowledge base of the agent.
config: Dict representation of agent configuration.
llm: The language model that will run the agent.
function_calling_llm: The language model that will handle the tool calling for this agent, it overrides the crew function_calling_llm.
max_iter: Maximum number of iterations for an agent to execute a task.
memory: Whether the agent should have memory or not.
max_rpm: Maximum number of requests per minute for the agent execution to be respected.
verbose: Whether the agent execution should be in verbose mode.
allow_delegation: Whether the agent is allowed to delegate tasks to other agents.
tools: Tools at agents disposal
step_callback: Callback to be executed after each step of the agent execution.
knowledge_sources: Knowledge sources for the agent.
Args:
role (Optional[str]): The role of the agent
goal (Optional[str]): The objective of the agent
backstory (Optional[str]): The backstory of the agent
allow_delegation (bool): Whether the agent can delegate tasks
config (Optional[Dict[str, Any]]): Configuration for the agent
verbose (bool): Whether to enable verbose output
max_rpm (Optional[int]): Maximum requests per minute
tools (Optional[List[Any]]): Tools available to the agent
llm (Optional[Union[str, Any]]): Language model to use
function_calling_llm (Optional[Any]): Language model for tool calling
max_iter (Optional[int]): Maximum iterations for task execution
memory (bool): Whether the agent should have memory
step_callback (Optional[Any]): Callback after each execution step
knowledge_sources (Optional[List[BaseKnowledgeSource]]): Knowledge sources
"""
model_config = {
"arbitrary_types_allowed": True,
"extra": "allow",
}
def __init__(
self,
role: Optional[str] = None,
goal: Optional[str] = None,
backstory: Optional[str] = None,
allow_delegation: bool = False,
config: Optional[Dict[str, Any]] = None,
verbose: bool = False,
max_rpm: Optional[int] = None,
tools: Optional[List[Any]] = None,
llm: Optional[Union[str, LLM, Any]] = None,
function_calling_llm: Optional[Any] = None,
max_iter: Optional[int] = None,
memory: bool = True,
step_callback: Optional[Any] = None,
knowledge_sources: Optional[List[BaseKnowledgeSource]] = None,
**kwargs
) -> None:
"""Initialize an Agent with the given parameters."""
# Process tools before passing to parent
processed_tools = []
if tools:
from crewai.tools import BaseTool
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif callable(tool):
# Convert function to BaseTool
processed_tools.append(Tool.from_function(tool))
else:
raise ValueError(f"Tool {tool} must be either a BaseTool instance or a callable")
# Process LLM before passing to parent
processed_llm = None
if isinstance(llm, str):
processed_llm = LLM(model=llm)
elif isinstance(llm, LLM):
processed_llm = llm
elif llm is not None and hasattr(llm, 'model') and hasattr(llm, 'temperature'):
# Handle ChatOpenAI and similar objects
model_name = getattr(llm, 'model', None)
if model_name is not None:
if not isinstance(model_name, str):
model_name = str(model_name)
processed_llm = LLM(
model=model_name,
temperature=getattr(llm, 'temperature', None),
api_key=getattr(llm, 'api_key', None),
base_url=getattr(llm, 'base_url', None)
)
# If no valid LLM configuration found, leave as None for post_init_setup
# Initialize all fields in a dict
init_dict = {
"role": role,
"goal": goal,
"backstory": backstory,
"allow_delegation": allow_delegation,
"config": config,
"verbose": verbose,
"max_rpm": max_rpm,
"tools": processed_tools,
"max_iter": max_iter if max_iter is not None else 25,
"function_calling_llm": function_calling_llm,
"step_callback": step_callback,
"knowledge_sources": knowledge_sources,
**kwargs
}
# Initialize base model with all fields
super().__init__(**init_dict)
# Store original values for interpolation
self._original_role = role
self._original_goal = goal
self._original_backstory = backstory
# Set LLM after base initialization to ensure proper model handling
self.llm = processed_llm
# Initialize private attributes
self._logger = Logger(verbose=self.verbose)
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
self._token_process = TokenProcess()
_times_executed: int = PrivateAttr(default=0)
max_execution_time: Optional[int] = Field(
@@ -138,21 +232,15 @@ class Agent(BaseAgent):
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
self.agent_ops_agent_name = self.role or "agent"
unaccepted_attributes = [
"AWS_ACCESS_KEY_ID",
"AWS_SECRET_ACCESS_KEY",
"AWS_REGION_NAME",
]
# Handle different cases for self.llm
if isinstance(self.llm, str):
# If it's a string, create an LLM instance
self.llm = LLM(model=self.llm)
elif isinstance(self.llm, LLM):
# If it's already an LLM instance, keep it as is
pass
elif self.llm is None:
# Handle LLM initialization if not already done
if self.llm is None:
# Determine the model name from environment variables or use default
model_name = (
os.environ.get("OPENAI_MODEL_NAME")
@@ -190,9 +278,71 @@ class Agent(BaseAgent):
if key not in ["prompt", "key_name", "default"]:
# Only add default if the key is already set in os.environ
if key in os.environ:
llm_params[key] = value
try:
# Create a new dictionary for properly typed parameters
typed_params = {}
# Convert and validate values based on parameter type
if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
if value is not None:
try:
typed_params[key] = float(value)
except (ValueError, TypeError):
pass
elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']:
if value is not None:
try:
typed_params[key] = int(value)
except (ValueError, TypeError):
pass
elif key == 'logit_bias' and isinstance(value, str):
try:
bias_dict = {}
for pair in value.split(','):
token_id, bias = pair.split(':')
bias_dict[int(token_id.strip())] = float(bias.strip())
typed_params[key] = bias_dict
except (ValueError, AttributeError):
pass
elif key == 'response_format' and isinstance(value, str):
try:
import json
typed_params[key] = json.loads(value)
except json.JSONDecodeError:
pass
elif key == 'logprobs':
if value is not None:
typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value)
elif key == 'callbacks':
typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value
elif key == 'stop':
typed_params[key] = [value] if isinstance(value, str) else value
elif key in ['model', 'base_url', 'api_version', 'api_key']:
typed_params[key] = value
# Update llm_params with properly typed values
if typed_params:
llm_params.update(typed_params)
except (ValueError, AttributeError, json.JSONDecodeError):
continue
self.llm = LLM(**llm_params)
# Create LLM instance with properly typed parameters
valid_params = {
'model', 'timeout', 'temperature', 'top_p', 'n', 'stop',
'max_completion_tokens', 'max_tokens', 'presence_penalty',
'frequency_penalty', 'logit_bias', 'response_format',
'seed', 'logprobs', 'top_logprobs', 'base_url',
'api_version', 'api_key', 'callbacks'
}
# Filter out None values and invalid parameters
filtered_params = {}
for k, v in llm_params.items():
if k in valid_params and v is not None:
filtered_params[k] = v
# Create LLM instance with properly typed parameters
self.llm = LLM(**filtered_params)
else:
# For any other type, attempt to extract relevant attributes
llm_params = {
@@ -239,7 +389,7 @@ class Agent(BaseAgent):
def _set_knowledge(self):
try:
if self.knowledge_sources:
knowledge_agent_name = f"{self.role.replace(' ', '_')}"
knowledge_agent_name = f"{(self.role or 'agent').replace(' ', '_')}"
if isinstance(self.knowledge_sources, list) and all(
isinstance(k, BaseKnowledgeSource) for k in self.knowledge_sources
):
@@ -384,6 +534,32 @@ class Agent(BaseAgent):
self.response_template.split("{{ .Response }}")[1].strip()
)
# Ensure LLM is initialized with proper error handling
try:
if not self.llm:
self.llm = LLM(model="gpt-4")
if hasattr(self, '_logger'):
self._logger.debug("Initialized default LLM with gpt-4 model")
except Exception as e:
if hasattr(self, '_logger'):
self._logger.error(f"Failed to initialize LLM: {str(e)}")
raise
# Create token callback with proper error handling
try:
token_callback = None
if hasattr(self, '_token_process'):
token_callback = TokenCalcHandler(self._token_process)
except Exception as e:
if hasattr(self, '_logger'):
self._logger.warning(f"Failed to create token callback: {str(e)}")
token_callback = None
# Initialize callbacks list
executor_callbacks = []
if token_callback:
executor_callbacks.append(token_callback)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
task=task,
@@ -401,9 +577,9 @@ class Agent(BaseAgent):
function_calling_llm=self.function_calling_llm,
respect_context_window=self.respect_context_window,
request_within_rpm_limit=(
self._rpm_controller.check_or_wait if self._rpm_controller else None
self._rpm_controller.check_or_wait if (hasattr(self, '_rpm_controller') and self._rpm_controller is not None) else None
),
callbacks=[TokenCalcHandler(self._token_process)],
callbacks=executor_callbacks,
)
def get_delegation_tools(self, agents: List[BaseAgent]):

View File

@@ -18,6 +18,7 @@ from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.tools_handler import ToolsHandler
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.tools import BaseTool
from crewai.tools.base_tool import Tool
from crewai.utilities import I18N, Logger, RPMController
@@ -87,9 +88,9 @@ class BaseAgent(ABC, BaseModel):
formatting_errors: int = Field(
default=0, description="Number of formatting errors."
)
role: str = Field(description="Role of the agent")
goal: str = Field(description="Objective of the agent")
backstory: str = Field(description="Backstory of the agent")
role: Optional[str] = Field(default=None, description="Role of the agent")
goal: Optional[str] = Field(default=None, description="Objective of the agent")
backstory: Optional[str] = Field(default=None, description="Backstory of the agent")
config: Optional[Dict[str, Any]] = Field(
description="Configuration for the agent", default=None, exclude=True
)
@@ -130,26 +131,47 @@ class BaseAgent(ABC, BaseModel):
max_tokens: Optional[int] = Field(
default=None, description="Maximum number of tokens for the agent's execution."
)
function_calling_llm: Optional[Any] = Field(
default=None, description="Language model for function calling."
)
step_callback: Optional[Any] = Field(
default=None, description="Callback for execution steps."
)
knowledge_sources: Optional[List[BaseKnowledgeSource]] = Field(
default=None, description="Knowledge sources for the agent."
)
model_config = {
"arbitrary_types_allowed": True,
"extra": "allow", # Allow extra fields in constructor
}
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values):
"""Process configuration values before model initialization."""
return process_config(values, cls)
@field_validator("tools")
@classmethod
def validate_tools(cls, tools: List[Any]) -> List[BaseTool]:
def validate_tools(cls, tools: Optional[List[Any]]) -> List[BaseTool]:
"""Validate and process the tools provided to the agent.
This method ensures that each tool is either an instance of BaseTool
or an object with 'name', 'func', and 'description' attributes. If the
tool meets these criteria, it is processed and added to the list of
tools. Otherwise, a ValueError is raised.
This method ensures that each tool is either an instance of BaseTool,
a function decorated with @tool, or an object with 'name', 'func',
and 'description' attributes. If the tool meets these criteria, it is
processed and added to the list of tools. Otherwise, a ValueError is raised.
"""
if not tools:
return []
processed_tools = []
for tool in tools:
if isinstance(tool, BaseTool):
processed_tools.append(tool)
elif callable(tool) and hasattr(tool, "_is_tool") and tool._is_tool:
# Handle @tool decorated functions
processed_tools.append(Tool.from_function(tool))
elif (
hasattr(tool, "name")
and hasattr(tool, "func")
@@ -157,31 +179,54 @@ class BaseAgent(ABC, BaseModel):
):
# Tool has the required attributes, create a Tool instance
processed_tools.append(Tool.from_langchain(tool))
else:
else:
raise ValueError(
f"Invalid tool type: {type(tool)}. "
"Tool must be an instance of BaseTool or "
"an object with 'name', 'func', and 'description' attributes."
"Tool must be an instance of BaseTool, a @tool decorated function, "
"or an object with 'name', 'func', and 'description' attributes."
)
return processed_tools
@model_validator(mode="after")
def validate_and_set_attributes(self):
# Validate required fields
for field in ["role", "goal", "backstory"]:
if getattr(self, field) is None:
raise ValueError(
f"{field} must be provided either directly or through config"
)
"""Validate and set attributes for the agent.
This method ensures that attributes are properly set and initialized,
either from direct parameters or configuration.
"""
# Store original values for interpolation
self._original_role = self.role
self._original_goal = self.goal
self._original_backstory = self.backstory
# Set private attributes
# Process config if provided
if self.config:
config_data = self.config
if isinstance(config_data, str):
import json
try:
config_data = json.loads(config_data)
except json.JSONDecodeError:
raise ValueError("Invalid JSON in config")
# Update fields from config if they're None
for field in ["role", "goal", "backstory"]:
if field in config_data and getattr(self, field) is None:
setattr(self, field, config_data[field])
# Set default values for required fields if they're still None
self.role = self.role or "Assistant"
self.goal = self.goal or "Help the user accomplish their tasks"
self.backstory = self.backstory or "I am an AI assistant ready to help"
# Initialize tools handler if not set
if not hasattr(self, 'tools_handler') or self.tools_handler is None:
self.tools_handler = ToolsHandler()
# Initialize logger and rpm controller
self._logger = Logger(verbose=self.verbose)
if self.max_rpm and not self._rpm_controller:
self._rpm_controller = RPMController(
max_rpm=self.max_rpm, logger=self._logger
)
if not self._token_process:
self._token_process = TokenProcess()
if self.max_rpm:
self._rpm_controller = RPMController(max_rpm=self.max_rpm, logger=self._logger)
return self
@@ -208,9 +253,9 @@ class BaseAgent(ABC, BaseModel):
@property
def key(self):
source = [
self._original_role or self.role,
self._original_goal or self.goal,
self._original_backstory or self.backstory,
str(self._original_role or self.role or ""),
str(self._original_goal or self.goal or ""),
str(self._original_backstory or self.backstory or ""),
]
return md5("|".join(source).encode(), usedforsecurity=False).hexdigest()
@@ -256,29 +301,45 @@ class BaseAgent(ABC, BaseModel):
"tools_handler",
"cache_handler",
"llm",
"function_calling_llm",
}
# Copy llm and clear callbacks
existing_llm = shallow_copy(self.llm)
# Copy LLMs and clear callbacks
existing_llm = shallow_copy(self.llm) if self.llm else None
existing_function_calling_llm = shallow_copy(self.function_calling_llm) if self.function_calling_llm else None
# Create base data
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
copied_agent = type(self)(**copied_data, llm=existing_llm, tools=self.tools)
# Create new instance with copied data
copied_agent = type(self)(
**copied_data,
llm=existing_llm,
function_calling_llm=existing_function_calling_llm,
tools=self.tools
)
# Copy private attributes
copied_agent._original_role = self._original_role
copied_agent._original_goal = self._original_goal
copied_agent._original_backstory = self._original_backstory
return copied_agent
def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
"""Interpolate inputs into the agent description and backstory."""
if self._original_role is None:
self._original_role = self.role
self._original_role = self.role or ""
if self._original_goal is None:
self._original_goal = self.goal
self._original_goal = self.goal or ""
if self._original_backstory is None:
self._original_backstory = self.backstory
self._original_backstory = self.backstory or ""
if inputs:
self.role = self._original_role.format(**inputs)
self.goal = self._original_goal.format(**inputs)
self.backstory = self._original_backstory.format(**inputs)
self.role = self._original_role.format(**inputs) if self._original_role else None
self.goal = self._original_goal.format(**inputs) if self._original_goal else None
self.backstory = self._original_backstory.format(**inputs) if self._original_backstory else None
def set_cache_handler(self, cache_handler: CacheHandler) -> None:
"""Set the cache handler for the agent.

View File

@@ -82,16 +82,17 @@ class CrewAgentExecutorMixin:
)
self.crew._long_term_memory.save(long_term_memory)
for entity in evaluation.entities:
entity_memory = EntityMemoryItem(
name=entity.name,
type=entity.type,
description=entity.description,
relationships="\n".join(
[f"- {r}" for r in entity.relationships]
),
)
self.crew._entity_memory.save(entity_memory)
if hasattr(evaluation, 'entities') and evaluation.entities:
for entity in evaluation.entities:
entity_memory = EntityMemoryItem(
name=entity.name,
type=entity.type,
description=entity.description,
relationships="\n".join(
[f"- {r}" for r in entity.relationships]
),
)
self.crew._entity_memory.save(entity_memory)
except AttributeError as e:
print(f"Missing attributes for long term memory: {e}")
pass

View File

@@ -68,7 +68,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.tools_handler = tools_handler
self.original_tools = original_tools
self.step_callback = step_callback
self.use_stop_words = self.llm.supports_stop_words()
self.use_stop_words = self.llm.supports_stop_words() if self.llm else False
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
@@ -147,7 +147,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
# Directly append the result to the messages if the
# tool is "Add image to content" in case of multimodal
# agents
if formatted_answer.tool == self._i18n.tools("add_image")["name"]:
add_image_tool_name = self._i18n.tools("add_image")
if add_image_tool_name and formatted_answer.tool == add_image_tool_name:
self.messages.append(tool_result.result)
continue
@@ -214,13 +215,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent.verbose or (
hasattr(self, "crew") and getattr(self.crew, "verbose", False)
):
agent_role = self.agent.role.split("\n")[0]
agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
self._printer.print(
content=f"\033[1m\033[95m# Agent:\033[00m \033[1m\033[92m{agent_role}\033[00m"
)
self._printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
)
if self.task and self.task.description:
self._printer.print(
content=f"\033[95m## Task:\033[00m \033[92m{self.task.description}\033[00m"
)
def _show_logs(self, formatted_answer: Union[AgentAction, AgentFinish]):
if self.agent is None:
@@ -228,7 +230,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
if self.agent.verbose or (
hasattr(self, "crew") and getattr(self.crew, "verbose", False)
):
agent_role = self.agent.role.split("\n")[0]
agent_role = self.agent.role.split("\n")[0] if self.agent and self.agent.role else ""
if isinstance(formatted_answer, AgentAction):
thought = re.sub(r"\n+", "\n", formatted_answer.thought)
formatted_json = json.dumps(

View File

@@ -5,6 +5,7 @@ from pathlib import Path
import click
import requests
from typing import Any
from crewai.cli.constants import JSON_URL, MODELS, PROVIDERS
@@ -192,7 +193,7 @@ def download_data(response):
data_chunks = []
with click.progressbar(
length=total_size, label="Downloading", show_pos=True
) as progress_bar:
) as progress_bar: # type: Any
for chunk in response.iter_content(block_size):
if chunk:
data_chunks.append(chunk)

View File

@@ -6,6 +6,8 @@ from concurrent.futures import Future
from hashlib import md5
from typing import Any, Callable, Dict, List, Optional, Tuple, Union
from crewai.tools.base_tool import BaseTool
from pydantic import (
UUID4,
BaseModel,
@@ -728,7 +730,7 @@ class Crew(BaseModel):
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task)
self._log_task_start(task, agent_to_use.role)
self._log_task_start(task, agent_to_use.role if agent_to_use and agent_to_use.role else "")
if isinstance(task, ConditionalTask):
skipped_task_output = self._handle_conditional_task(
@@ -794,8 +796,8 @@ class Crew(BaseModel):
return None
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: List[Tool]
) -> List[Tool]:
self, agent: BaseAgent, task: Task, tools: List[Union[Tool, BaseTool]]
) -> List[Union[Tool, BaseTool]]:
# Add delegation tools if agent allows delegation
if agent.allow_delegation:
if self.process == Process.hierarchical:
@@ -824,8 +826,8 @@ class Crew(BaseModel):
return task.agent
def _merge_tools(
self, existing_tools: List[Tool], new_tools: List[Tool]
) -> List[Tool]:
self, existing_tools: List[Union[Tool, BaseTool]], new_tools: List[Union[Tool, BaseTool]]
) -> List[Union[Tool, BaseTool]]:
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
if not new_tools:
return existing_tools

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from pydantic import BaseModel, Field
@@ -23,14 +23,25 @@ class CrewOutput(BaseModel):
)
token_usage: UsageMetrics = Field(description="Processed token summary", default={})
@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.tasks_output[-1].output_format != OutputFormat.JSON:
raise ValueError(
"No JSON output found in the final task. Please make sure to set the output_json property in the final task in your crew."
)
return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""

View File

@@ -106,7 +106,12 @@ class FlowPlot:
# Add nodes to the network
try:
add_nodes_to_network(net, self.flow, node_positions, self.node_styles)
add_nodes_to_network(
net,
flow=self.flow,
pos=node_positions,
node_styles=self.node_styles
)
except Exception as e:
raise RuntimeError(f"Failed to add nodes to network: {str(e)}")

View File

@@ -6,6 +6,8 @@ import warnings
from contextlib import contextmanager
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
@@ -93,10 +95,33 @@ def suppress_warnings():
sys.stderr = old_stderr
class LLM:
class LLM(BaseModel):
model: str = "gpt-4" # Set default model
timeout: Optional[Union[float, int]] = None
temperature: Optional[float] = None
top_p: Optional[float] = None
n: Optional[int] = None
stop: Optional[Union[str, List[str]]] = None
max_completion_tokens: Optional[int] = None
max_tokens: Optional[int] = None
presence_penalty: Optional[float] = None
frequency_penalty: Optional[float] = None
logit_bias: Optional[Dict[int, float]] = None
response_format: Optional[Dict[str, Any]] = None
seed: Optional[int] = None
logprobs: Optional[bool] = None
top_logprobs: Optional[int] = None
base_url: Optional[str] = None
api_version: Optional[str] = None
api_key: Optional[str] = None
callbacks: Optional[List[Any]] = None
context_window_size: Optional[int] = None
kwargs: Dict[str, Any] = Field(default_factory=dict)
logger: Optional[logging.Logger] = Field(default_factory=lambda: logging.getLogger(__name__))
def __init__(
self,
model: str,
model: Optional[Union[str, 'LLM']] = "gpt-4",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
@@ -114,118 +139,427 @@ class LLM:
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
**kwargs,
):
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.kwargs = kwargs
callbacks: Optional[List[Any]] = None,
context_window_size: Optional[int] = None,
**kwargs: Any,
) -> None:
# Initialize with default values
init_dict = {
"model": model if isinstance(model, str) else "gpt-4",
"timeout": timeout,
"temperature": temperature,
"top_p": top_p,
"n": n,
"stop": stop,
"max_completion_tokens": max_completion_tokens,
"max_tokens": max_tokens,
"presence_penalty": presence_penalty,
"frequency_penalty": frequency_penalty,
"logit_bias": logit_bias,
"response_format": response_format,
"seed": seed,
"logprobs": logprobs,
"top_logprobs": top_logprobs,
"base_url": base_url,
"api_version": api_version,
"api_key": api_key,
"callbacks": callbacks,
"context_window_size": context_window_size,
"kwargs": kwargs,
}
super().__init__(**init_dict)
# Initialize model with default value
self.model = "gpt-4" # Default fallback
# Extract and validate model name
if isinstance(model, LLM):
# Extract and validate model name from LLM instance
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# If model is an LLM instance, copy its configuration
if isinstance(model, LLM):
# Extract and validate model name first
if hasattr(model, 'model'):
if isinstance(model.model, str):
self.model = model.model
else:
# Try to extract string model name from nested LLM
if isinstance(model.model, LLM):
self.model = str(model.model.model) if hasattr(model.model, 'model') else "gpt-4"
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Nested LLM model is not a string, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("LLM instance has no model attribute, using default: gpt-4")
# Copy other configuration
self.timeout = model.timeout
self.temperature = model.temperature
self.top_p = model.top_p
self.n = model.n
self.stop = model.stop
self.max_completion_tokens = model.max_completion_tokens
self.max_tokens = model.max_tokens
self.presence_penalty = model.presence_penalty
self.frequency_penalty = model.frequency_penalty
self.logit_bias = model.logit_bias
self.response_format = model.response_format
self.seed = model.seed
self.logprobs = model.logprobs
self.top_logprobs = model.top_logprobs
self.base_url = model.base_url
self.api_version = model.api_version
self.api_key = model.api_key
self.callbacks = model.callbacks
self.context_window_size = model.context_window_size
self.kwargs = model.kwargs
# Final validation of model name
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after LLM copy, using default: gpt-4")
else:
# Extract and validate model name for non-LLM instances
if not isinstance(model, str):
if self.logger:
self.logger.debug(f"Model is not a string, attempting to extract name. Type: {type(model)}")
if model is not None:
if hasattr(model, 'model_name'):
model_name = getattr(model, 'model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
elif hasattr(model, 'model'):
model_attr = getattr(model, 'model', None)
self.model = str(model_attr) if model_attr is not None else "gpt-4"
elif hasattr(model, '_model_name'):
model_name = getattr(model, '_model_name', None)
self.model = str(model_name) if model_name is not None else "gpt-4"
else:
self.model = "gpt-4" # Default fallback
if self.logger:
self.logger.warning(f"Could not extract model name from {type(model)}, using default: {self.model}")
else:
self.model = "gpt-4" # Default fallback for None
if self.logger:
self.logger.warning("Model is None, using default: gpt-4")
else:
self.model = str(model) # Ensure it's a string
# Final validation
if not isinstance(self.model, str):
self.model = "gpt-4"
if self.logger:
self.logger.warning("Model name is still not a string after extraction, using default: gpt-4")
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.stop = stop
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.kwargs = kwargs
# Ensure model is a string after initialization
if not isinstance(self.model, str):
self.model = "gpt-4"
self.logger.warning(f"Model is still not a string after initialization, using default: {self.model}")
litellm.drop_params = True
self.set_callbacks(callbacks)
self.set_env_callbacks()
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
callbacks: Optional[List[Any]] = None
) -> str:
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
# Store original model to restore later
original_model = self.model
try:
# Ensure model is a string before making the call
if not isinstance(self.model, str):
if self.logger:
self.logger.warning(f"Model is not a string in call method: {type(self.model)}. Attempting to convert...")
if isinstance(self.model, LLM):
self.model = self.model.model if isinstance(self.model.model, str) else "gpt-4"
elif hasattr(self.model, 'model_name'):
self.model = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
self.model = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
self.model = str(self.model.model.model_name)
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name from nested model object, using default: gpt-4")
else:
self.model = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
if self.logger:
self.logger.debug(f"Using model: {self.model} (type: {type(self.model)}) for LiteLLM call")
# Create base params with validated model name
# Extract model name string
model_name = None
if isinstance(self.model, str):
model_name = self.model
elif hasattr(self.model, 'model_name'):
model_name = str(self.model.model_name)
elif hasattr(self.model, 'model'):
if isinstance(self.model.model, str):
model_name = str(self.model.model)
elif hasattr(self.model.model, 'model_name'):
model_name = str(self.model.model.model_name)
if not model_name:
model_name = "gpt-4"
if self.logger:
self.logger.warning("Could not extract model name, using default: gpt-4")
params = {
"model": self.model,
"model": model_name,
"messages": messages,
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs,
"stream": False,
"api_key": self.api_key or os.getenv("OPENAI_API_KEY"),
"api_base": self.base_url,
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
**self.kwargs,
}
if self.logger:
self.logger.debug(f"Using model parameters: {params}")
# Add API configuration if available
api_key = self.api_key or os.getenv("OPENAI_API_KEY")
if api_key:
params["api_key"] = api_key
# Try to get supported parameters for the model
try:
supported_params = get_supported_openai_params(self.model)
optional_params = {}
if supported_params:
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
# Only add parameters that are supported and not None
optional_params = {
k: v for k, v in param_mapping.items()
if k in supported_params and v is not None
}
if "logprobs" in supported_params and self.logprobs is not None:
optional_params["logprobs"] = self.logprobs
if "top_logprobs" in supported_params and self.top_logprobs is not None:
optional_params["top_logprobs"] = self.top_logprobs
except Exception as e:
if self.logger:
self.logger.error(f"Failed to get supported params for model {self.model}: {str(e)}")
# If we can't get supported params, just add non-None parameters
param_mapping = {
"timeout": self.timeout,
"temperature": self.temperature,
"top_p": self.top_p,
"n": self.n,
"stop": self.stop,
"max_tokens": self.max_tokens or self.max_completion_tokens,
"presence_penalty": self.presence_penalty,
"frequency_penalty": self.frequency_penalty,
"logit_bias": self.logit_bias,
"response_format": self.response_format,
"seed": self.seed,
"logprobs": self.logprobs,
"top_logprobs": self.top_logprobs
}
optional_params = {k: v for k, v in param_mapping.items() if v is not None}
# Update params with optional parameters
params.update(optional_params)
# Add API endpoint configuration if available
if self.base_url:
params["api_base"] = self.base_url
if self.api_version:
params["api_version"] = self.api_version
# Final validation of model parameter
if not isinstance(params["model"], str):
if self.logger:
self.logger.error(f"Model is still not a string after all conversions: {type(params['model'])}")
params["model"] = "gpt-4"
# Update params with non-None optional parameters
params.update({k: v for k, v in optional_params.items() if v is not None})
# Add any additional kwargs
if self.kwargs:
params.update(self.kwargs)
# Remove None values to avoid passing unnecessary parameters
params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
content = response["choices"][0]["message"]["content"]
# Extract usage metrics
usage = response.get("usage", {})
if callbacks:
for callback in callbacks:
if hasattr(callback, "update_token_usage"):
callback.update_token_usage(usage)
return content
except Exception as e:
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging
finally:
# Always restore the original model object
self.model = original_model
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
bool: True if the model supports function calling, False otherwise
"""
try:
params = get_supported_openai_params(model=self.model)
return "response_format" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns False if the LLM is not properly initialized."""
if not hasattr(self, 'model') or self.model is None:
return False
try:
params = get_supported_openai_params(model=self.model)
return "stop" in params
except Exception as e:
logging.error(f"Failed to get supported params: {str(e)}")
if self.logger:
self.logger.error(f"Failed to get supported params: {str(e)}")
return False
def get_context_window_size(self) -> int:
"""Get the context window size for the current model.
Returns:
int: The context window size in tokens
"""
# Only using 75% of the context window size to avoid cutting the message in the middle
if self.context_window_size != 0:
return self.context_window_size
if self.context_window_size is not None and self.context_window_size != 0:
return int(self.context_window_size)
self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
)
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
self.context_window_size = int(value * CONTEXT_WINDOW_USAGE_RATIO)
window_size = DEFAULT_CONTEXT_WINDOW_SIZE
if isinstance(self.model, str):
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if self.model.startswith(key):
window_size = value
break
self.context_window_size = int(window_size * CONTEXT_WINDOW_USAGE_RATIO)
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
def set_callbacks(self, callbacks: Optional[List[Any]] = None) -> None:
"""Set callbacks for the LLM.
Args:
callbacks: Optional list of callback functions. If None, no callbacks will be set.
"""
if callbacks is not None:
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
if type(callback) in callback_types:
litellm.success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
for callback in litellm._async_success_callback[:]:
if type(callback) in callback_types:
litellm._async_success_callback.remove(callback)
litellm.callbacks = callbacks
litellm.callbacks = callbacks
def set_env_callbacks(self):
"""

View File

@@ -269,7 +269,9 @@ class Task(BaseModel):
@model_validator(mode="after")
def check_tools(self):
"""Check if the tools are set."""
if not self.tools and self.agent and self.agent.tools:
if self.agent and self.agent.tools:
if self.tools is None:
self.tools = []
self.tools.extend(self.agent.tools)
return self
@@ -348,7 +350,8 @@ class Task(BaseModel):
self.prompt_context = context
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
if agent and agent.role:
self.processed_by_agents.add(agent.role)
result = agent.execute_task(
task=self,

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional
from typing import Any, Callable, Dict, Optional, Union
from pydantic import BaseModel, Field, model_validator
@@ -34,8 +34,19 @@ class TaskOutput(BaseModel):
self.summary = f"{excerpt}..."
return self
@property
def json(self) -> Optional[str]:
def json(
self,
*,
include: Union[set[str], None] = None,
exclude: Union[set[str], None] = None,
by_alias: bool = False,
exclude_unset: bool = False,
exclude_defaults: bool = False,
exclude_none: bool = False,
encoder: Optional[Callable[[Any], Any]] = None,
models_as_dict: bool = True,
**dumps_kwargs: Any,
) -> str:
if self.output_format != OutputFormat.JSON:
raise ValueError(
"""
@@ -45,7 +56,7 @@ class TaskOutput(BaseModel):
"""
)
return json.dumps(self.json_dict)
return json.dumps(self.json_dict, default=encoder, **dumps_kwargs)
def to_dict(self) -> Dict[str, Any]:
"""Convert json_output and pydantic_output to a dictionary."""

View File

@@ -19,13 +19,13 @@ class BaseAgentTool(BaseTool):
default_factory=I18N, description="Internationalization settings"
)
def sanitize_agent_name(self, name: str) -> str:
def sanitize_agent_name(self, name: Optional[str]) -> str:
"""
Sanitize agent role name by normalizing whitespace and setting to lowercase.
Converts all whitespace (including newlines) to single spaces and removes quotes.
Args:
name (str): The agent role name to sanitize
name (Optional[str]): The agent role name to sanitize
Returns:
str: The sanitized agent role name, with whitespace normalized,

View File

@@ -142,7 +142,12 @@ class CrewStructuredTool:
# Create model
schema_name = f"{name.title()}Schema"
return create_model(schema_name, **fields)
return create_model(
schema_name,
__base__=BaseModel,
__config__=None,
**{k: v for k, v in fields.items()}
)
def _validate_function_signature(self) -> None:
"""Validate that the function signature matches the args schema."""
@@ -170,7 +175,7 @@ class CrewStructuredTool:
f"not found in args_schema"
)
def _parse_args(self, raw_args: Union[str, dict]) -> dict:
def _parse_args(self, raw_args: Union[str, dict[str, Any]]) -> dict[str, Any]:
"""Parse and validate the input arguments against the schema.
Args:
@@ -178,6 +183,9 @@ class CrewStructuredTool:
Returns:
The validated arguments as a dictionary
Raises:
ValueError: If the arguments cannot be parsed or fail validation
"""
if isinstance(raw_args, str):
try:
@@ -195,8 +203,8 @@ class CrewStructuredTool:
async def ainvoke(
self,
input: Union[str, dict],
config: Optional[dict] = None,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any,
) -> Any:
"""Asynchronously invoke the tool.
@@ -229,7 +237,10 @@ class CrewStructuredTool:
return self.invoke(input_dict)
def invoke(
self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
self,
input: Union[str, dict[str, Any]],
config: Optional[dict[str, Any]] = None,
**kwargs: Any
) -> Any:
"""Main method for tool execution."""
parsed_args = self._parse_args(input)

View File

@@ -10,8 +10,24 @@ class Logger(BaseModel):
_printer: Printer = PrivateAttr(default_factory=Printer)
def log(self, level, message, color="bold_yellow"):
if self.verbose:
if self.verbose or level.upper() in ["WARNING", "ERROR"]:
timestamp = datetime.now().strftime("%Y-%m-%d %H:%M:%S")
self._printer.print(
f"\n[{timestamp}][{level.upper()}]: {message}", color=color
)
def debug(self, message: str) -> None:
"""Log a debug message if verbose is enabled."""
self.log("debug", message, color="bold_blue")
def info(self, message: str) -> None:
"""Log an info message if verbose is enabled."""
self.log("info", message, color="bold_green")
def warning(self, message: str) -> None:
"""Log a warning message."""
self.log("warning", message, color="bold_yellow")
def error(self, message: str) -> None:
"""Log an error message."""
self.log("error", message, color="bold_red")

View File

@@ -63,16 +63,32 @@ class Prompts(BaseModel):
for component in components
if component != "task"
]
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
prompt = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
)
response = response_template.split("{{ .Response }}")[0]
prompt = f"{system}\n{prompt}\n{response}"
system = ""
if system_template:
system = system_template.replace("{{ .System }}", "".join(prompt_parts))
prompt_text = ""
if prompt_template:
prompt_text = prompt_template.replace(
"{{ .Prompt }}", "".join(self.i18n.slice("task"))
)
response = ""
if response_template:
response = response_template.split("{{ .Response }}")[0]
parts = [p for p in [system, prompt_text, response] if p]
prompt = "\n".join(parts) if parts else ""
# Get agent attributes with default values
goal = str(getattr(self.agent, 'goal', '') or '')
role = str(getattr(self.agent, 'role', '') or '')
backstory = str(getattr(self.agent, 'backstory', '') or '')
# Replace placeholders with agent attributes
prompt = (
prompt.replace("{goal}", self.agent.goal)
.replace("{role}", self.agent.role)
.replace("{backstory}", self.agent.backstory)
prompt.replace("{goal}", goal)
.replace("{role}", role)
.replace("{backstory}", backstory)
)
return prompt

View File

@@ -0,0 +1,74 @@
"""Token processing utility for tracking and managing token usage."""
from __future__ import annotations
from typing import Any, Dict, List, Optional, Union
from crewai.types.usage_metrics import UsageMetrics
class TokenProcess:
"""Handles token processing and tracking for agents."""
def __init__(self):
"""Initialize the token processor."""
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0
def sum_prompt_tokens(self, count: int) -> None:
"""Add to prompt token count.
Args:
count (int): Number of prompt tokens to add
"""
self._prompt_tokens += count
self._total_tokens += count
def sum_completion_tokens(self, count: int) -> None:
"""Add to completion token count.
Args:
count (int): Number of completion tokens to add
"""
self._completion_tokens += count
self._total_tokens += count
def sum_cached_prompt_tokens(self, count: int) -> None:
"""Add to cached prompt token count.
Args:
count (int): Number of cached prompt tokens to add
"""
self._cached_prompt_tokens += count
def sum_successful_requests(self, count: int) -> None:
"""Add to successful requests count.
Args:
count (int): Number of successful requests to add
"""
self._successful_requests += count
def reset(self) -> None:
"""Reset all token counts to zero."""
self._total_tokens = 0
self._prompt_tokens = 0
self._completion_tokens = 0
self._cached_prompt_tokens = 0
self._successful_requests = 0
def get_summary(self) -> UsageMetrics:
"""Get a summary of token usage.
Returns:
UsageMetrics: Object containing token usage metrics
"""
return UsageMetrics(
total_tokens=self._total_tokens,
prompt_tokens=self._prompt_tokens,
cached_prompt_tokens=self._cached_prompt_tokens,
completion_tokens=self._completion_tokens,
successful_requests=self._successful_requests
)

View File

@@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools():
crew.kickoff()
# Verify task tools override agent tools
assert len(task.tools) == 1 # AnotherTestTool
assert any(isinstance(tool, AnotherTestTool) for tool in task.tools)
assert not any(isinstance(tool, TestTool) for tool in task.tools)
tools = task.tools or []
assert len(tools) == 1 # AnotherTestTool
assert any(isinstance(tool, AnotherTestTool) for tool in tools)
assert not any(isinstance(tool, TestTool) for tool in tools)
# Verify agent tools remain unchanged
assert len(new_researcher.tools) == 1

6
tests/pytest.ini Normal file
View File

@@ -0,0 +1,6 @@
[pytest]
markers =
vcr: Mark a test as using VCR.py for recording/replaying HTTP interactions
[vcr]
record_mode = none