mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
* feat: implement tool usage limit exception handling
  - Introduced `ToolUsageLimitExceeded` exception to manage maximum usage limits for tools.
  - Enhanced `CrewStructuredTool` to check and raise this exception when the usage limit is reached.
  - Updated `_run` and `_execute` methods to include usage limit checks and handle exceptions appropriately, improving reliability and user feedback.
* feat: enhance PlusAPI and ToolUsage with task metadata
  - Removed the `send_trace_batch` method from PlusAPI to streamline the API.
  - Added timeout parameters to trace event methods in PlusAPI for improved reliability.
  - Updated ToolUsage to include task metadata (task name and ID) in event emissions, enhancing traceability and context during tool usage.
  - Refactored event handling in LLM and ToolUsage events to ensure task information is consistently captured.
* feat: enhance memory and event handling with task and agent metadata
  - Added task and agent metadata to various memory and event classes, improving traceability and context during memory operations.
  - Updated the `ContextualMemory` and `Memory` classes to associate tasks and agents, allowing for better context management.
  - Enhanced event emissions in `LLM`, `ToolUsage`, and memory events to include task and agent information, facilitating improved debugging and monitoring.
  - Refactored event handling to ensure consistent capture of task and agent details across the system.
* drop
* refactor: clean up unused imports in memory and event modules
  - Removed unused TYPE_CHECKING imports from long_term_memory.py to streamline the code.
  - Eliminated unnecessary import from memory_events.py, enhancing clarity and maintainability.
* fix memory tests
* fix task_completed payload
* fix: remove unused test agent variable in external memory tests
* refactor: remove unused agent parameter from Memory class save method
  - Eliminated the agent parameter from the save method in the Memory class to streamline the code and improve clarity.
  - Updated the TraceBatchManager class by moving initialization of attributes into the constructor for better organization and readability.
* refactor: enhance ExecutionState and ReasoningEvent classes with optional task and agent identifiers
  - Added optional `current_agent_id` and `current_task_id` attributes to the `ExecutionState` class for better tracking of agent and task states.
  - Updated the `from_task` attribute in the `ReasoningEvent` class to use `Optional[Any]` instead of a specific type, improving flexibility in event handling.
* refactor: update ExecutionState class by removing unused agent and task identifiers
  - Removed the `current_agent_id` and `current_task_id` attributes from the `ExecutionState` class to simplify the code and enhance clarity.
  - Adjusted the import statements to include `Optional` for better type handling.
* refactor: streamline LLM event handling in LiteAgent
  - Removed unused LLM event emissions (LLMCallStartedEvent, LLMCallCompletedEvent, LLMCallFailedEvent) from the LiteAgent class to simplify the code and improve performance.
  - Adjusted the flow of LLM response handling by eliminating unnecessary event bus interactions, enhancing clarity and maintainability.
* flow ownership and not emitting events when a crew is done
* refactor: remove unused agent parameter from ShortTermMemory save method
  - Eliminated the agent parameter from the save method in the ShortTermMemory class to streamline the code and improve clarity.
  - This change enhances the maintainability of the memory management system by reducing unnecessary complexity.
* runtype check fix
* fixing tests
* fix lints
* fix: update event assertions in test_llm_emits_event_with_lite_agent
  - Adjusted the expected counts for completed and started events in the test to reflect the correct behavior of the LiteAgent.
  - Updated assertions for agent roles and IDs to match the expected values after recent changes in event handling.
* fix: update task name assertions in event tests
  - Modified assertions in `test_stream_llm_emits_event_with_task_and_agent_info` and `test_llm_emits_event_with_task_and_agent_info` to use `task.description` as a fallback for `task.name`. This ensures that the tests correctly validate the task name even when it is not explicitly set.
* fix: update test assertions for output values and improve readability
  - Updated assertions in `test_output_json_dict_hierarchical` to reflect the correct expected score value.
  - Enhanced readability of assertions in `test_output_pydantic_to_another_task` and `test_key` by formatting the error messages for clarity.
  - These changes ensure that the tests accurately validate the expected outputs and improve overall code quality.
* test fixes
* fix crew_test
* added another fixture
* fix: ensure agent and task assignments in contextual memory are conditional
  - Updated the ContextualMemory class to check for the existence of short-term, long-term, external, and extended memory before assigning agent and task attributes. This prevents potential attribute errors when memory types are not initialized.
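The first change in this commit adds a usage cap. Before each call, the tool compares its counter with `max_usage_count` and raises `ToolUsageLimitExceeded` once the cap is reached. The sketch below isolates only that guard.

`LimitedTool` is a hypothetical stand-in, not a CrewAI class. It mirrors the `max_usage_count` / `current_usage_count` / `has_reached_max_usage_count()` logic of `CrewStructuredTool` in the listing. It deliberately leaves out argument parsing, schemas, and the `_original_tool` synchronisation.

```python
from typing import Any, Callable, Optional


class ToolUsageLimitExceeded(Exception):
    """Raised when a tool is called after reaching its usage limit."""


class LimitedTool:
    """Hypothetical stand-in that reproduces only the usage-limit guard."""

    def __init__(
        self,
        name: str,
        func: Callable[..., Any],
        max_usage_count: Optional[int] = None,  # None means unlimited
    ) -> None:
        self.name = name
        self.func = func
        self.max_usage_count = max_usage_count
        self.current_usage_count = 0

    def has_reached_max_usage_count(self) -> bool:
        return (
            self.max_usage_count is not None
            and self.current_usage_count >= self.max_usage_count
        )

    def invoke(self, **kwargs: Any) -> Any:
        # Check the cap before counting or running the call.
        if self.has_reached_max_usage_count():
            raise ToolUsageLimitExceeded(
                f"Tool '{self.name}' has reached its maximum usage limit "
                f"of {self.max_usage_count}."
            )
        # The counter is bumped before func runs, so a call that raises still counts.
        self.current_usage_count += 1
        return self.func(**kwargs)


add = LimitedTool("add", lambda a, b: a + b, max_usage_count=2)
print(add.invoke(a=1, b=2))  # 3
print(add.invoke(a=2, b=3))  # 5
try:
    add.invoke(a=3, b=4)
except ToolUsageLimitExceeded as exc:
    print(exc)  # Tool 'add' has reached its maximum usage limit of 2.
```

As in the listing, the counter is incremented before the wrapped function executes. The cap therefore limits attempts rather than successful calls.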
314 lines
10 KiB
Python
from __future__ import annotations

import asyncio
import inspect
import textwrap
from typing import TYPE_CHECKING, Any, Callable, Optional, Union, get_type_hints

from pydantic import BaseModel, Field, create_model

from crewai.utilities.logger import Logger

if TYPE_CHECKING:
    from crewai.tools.base_tool import BaseTool


class ToolUsageLimitExceeded(Exception):
    """Exception raised when a tool has reached its maximum usage limit."""


class CrewStructuredTool:
    """A structured tool that can operate on any number of inputs.

    This class is intended to replace StructuredTool with a custom implementation
    that integrates better with CrewAI's ecosystem.
    """

    _original_tool: BaseTool | None = None

    def __init__(
        self,
        name: str,
        description: str,
        args_schema: type[BaseModel],
        func: Callable[..., Any],
        result_as_answer: bool = False,
        max_usage_count: int | None = None,
        current_usage_count: int = 0,
    ) -> None:
        """Initialize the structured tool.

        Args:
            name: The name of the tool
            description: A description of what the tool does
            args_schema: The pydantic model for the tool's arguments
            func: The function to run when the tool is called
            result_as_answer: Whether to return the output directly
            max_usage_count: Maximum number of times this tool can be used.
                None means unlimited usage.
            current_usage_count: Number of times this tool has already been used.
        """
        self.name = name
        self.description = description
        self.args_schema = args_schema
        self.func = func
        self._logger = Logger()
        self.result_as_answer = result_as_answer
        self.max_usage_count = max_usage_count
        self.current_usage_count = current_usage_count
        self._original_tool = None

        # Validate that the function signature matches the schema
        self._validate_function_signature()

    @classmethod
    def from_function(
        cls,
        func: Callable,
        name: Optional[str] = None,
        description: Optional[str] = None,
        return_direct: bool = False,
        args_schema: Optional[type[BaseModel]] = None,
        infer_schema: bool = True,
        **kwargs: Any,
    ) -> CrewStructuredTool:
        """Create a tool from a function.

        Args:
            func: The function to create a tool from
            name: The name of the tool. Defaults to the function name
            description: The description of the tool. Defaults to the function docstring
            return_direct: Whether to return the output directly (stored as
                ``result_as_answer``)
            args_schema: Optional schema for the function arguments
            infer_schema: Whether to infer the schema from the function signature
            **kwargs: Accepted for compatibility; currently ignored

        Returns:
            A CrewStructuredTool instance

        Raises:
            ValueError: If no description or docstring is available, or if no
                schema is provided and ``infer_schema`` is False.

        Example:
            >>> def add(a: int, b: int) -> int:
            ...     '''Add two numbers'''
            ...     return a + b
            >>> tool = CrewStructuredTool.from_function(add)
        """
        name = name or func.__name__
        description = description or inspect.getdoc(func)

        if description is None:
            raise ValueError(
                f"Function {name} must have a docstring if description not provided."
            )

        # Clean up the description
        description = textwrap.dedent(description).strip()

        if args_schema is not None:
            # Use provided schema
            schema = args_schema
        elif infer_schema:
            # Infer schema from function signature
            schema = cls._create_schema_from_function(name, func)
        else:
            raise ValueError(
                "Either args_schema must be provided or infer_schema must be True."
            )

        return cls(
            name=name,
            description=description,
            args_schema=schema,
            func=func,
            result_as_answer=return_direct,
        )

    @staticmethod
    def _create_schema_from_function(
        name: str,
        func: Callable,
    ) -> type[BaseModel]:
        """Create a Pydantic schema from a function's signature.

        Args:
            name: The name to use for the schema
            func: The function to create a schema from

        Returns:
            A Pydantic model class
        """
        # Get function signature
        sig = inspect.signature(func)

        # Get type hints
        type_hints = get_type_hints(func)

        # Create field definitions
        fields = {}
        for param_name, param in sig.parameters.items():
            # Skip self/cls for methods
            if param_name in ("self", "cls"):
                continue

            # Skip *args and **kwargs; they cannot be expressed as named fields
            if param.kind in (
                inspect.Parameter.VAR_POSITIONAL,
                inspect.Parameter.VAR_KEYWORD,
            ):
                continue

            # Get type annotation
            annotation = type_hints.get(param_name, Any)

            # Get default value (Ellipsis marks the field as required)
            default = ... if param.default is param.empty else param.default

            # Add field
            fields[param_name] = (annotation, Field(default=default))

        # Create model
        schema_name = f"{name.title()}Schema"
        return create_model(schema_name, **fields)

    def _validate_function_signature(self) -> None:
        """Validate that the function signature matches the args schema."""
        sig = inspect.signature(self.func)
        schema_fields = self.args_schema.model_fields

        # Check required parameters
        for param_name, param in sig.parameters.items():
            # Skip self/cls for methods
            if param_name in ("self", "cls"):
                continue

            # Skip *args and **kwargs parameters
            if param.kind in (
                inspect.Parameter.VAR_KEYWORD,
                inspect.Parameter.VAR_POSITIONAL,
            ):
                continue

            # Only validate required parameters without defaults
            if param.default is inspect.Parameter.empty:
                if param_name not in schema_fields:
                    raise ValueError(
                        f"Required function parameter '{param_name}' "
                        "not found in args_schema"
                    )

    def _parse_args(self, raw_args: Union[str, dict]) -> dict:
        """Parse and validate the input arguments against the schema.

        Args:
            raw_args: The raw arguments to parse, either as a JSON string or a dict

        Returns:
            The validated arguments as a dictionary

        Raises:
            ValueError: If the string is not valid JSON or validation fails.
        """
        if isinstance(raw_args, str):
            import json

            try:
                raw_args = json.loads(raw_args)
            except json.JSONDecodeError as e:
                raise ValueError(f"Failed to parse arguments as JSON: {e}") from e

        try:
            validated_args = self.args_schema.model_validate(raw_args)
            return validated_args.model_dump()
        except Exception as e:
            raise ValueError(f"Arguments validation failed: {e}") from e

    async def ainvoke(
        self,
        input: Union[str, dict],
        config: Optional[dict] = None,
        **kwargs: Any,
    ) -> Any:
        """Asynchronously invoke the tool.

        Args:
            input: The input arguments, as a JSON string or a dict
            config: Optional configuration (currently unused)
            **kwargs: Additional keyword arguments passed through to the function

        Returns:
            The result of the tool execution

        Raises:
            ToolUsageLimitExceeded: If the tool has reached its usage limit.
        """
        parsed_args = self._parse_args(input)

        if self.has_reached_max_usage_count():
            raise ToolUsageLimitExceeded(
                f"Tool '{self.name}' has reached its maximum usage limit of "
                f"{self.max_usage_count}. You should not use the {self.name} tool again."
            )

        self._increment_usage_count()

        if inspect.iscoroutinefunction(self.func):
            return await self.func(**parsed_args, **kwargs)

        # Run sync functions in the default thread pool so they don't block the loop
        loop = asyncio.get_running_loop()
        return await loop.run_in_executor(
            None, lambda: self.func(**parsed_args, **kwargs)
        )

    def _run(self, *args, **kwargs) -> Any:
        """Legacy entry point kept for compatibility; delegates to invoke()."""
        # Map positional args onto schema fields in declaration order,
        # then let keyword args fill in or override them
        input_dict = dict(zip(self.args_schema.model_fields.keys(), args))
        input_dict.update(kwargs)
        return self.invoke(input_dict)

    def invoke(
        self, input: Union[str, dict], config: Optional[dict] = None, **kwargs: Any
    ) -> Any:
        """Synchronously invoke the tool.

        Coroutine functions are run with ``asyncio.run()``, so this method must not
        be called from a thread that already has a running event loop; use
        ``ainvoke()`` there instead.
        """
        parsed_args = self._parse_args(input)

        if self.has_reached_max_usage_count():
            raise ToolUsageLimitExceeded(
                f"Tool '{self.name}' has reached its maximum usage limit of "
                f"{self.max_usage_count}. You should not use the {self.name} tool again."
            )

        self._increment_usage_count()

        if inspect.iscoroutinefunction(self.func):
            return asyncio.run(self.func(**parsed_args, **kwargs))

        # Call the function exactly once
        result = self.func(**parsed_args, **kwargs)

        # A non-coroutine callable (e.g. a wrapper) may still return a coroutine
        if asyncio.iscoroutine(result):
            return asyncio.run(result)

        return result

    def has_reached_max_usage_count(self) -> bool:
        """Check if the tool has reached its maximum usage count."""
        return (
            self.max_usage_count is not None
            and self.current_usage_count >= self.max_usage_count
        )

    def _increment_usage_count(self) -> None:
        """Increment the usage count and mirror it onto the original tool, if any."""
        self.current_usage_count += 1
        if self._original_tool is not None:
            self._original_tool.current_usage_count = self.current_usage_count

    @property
    def args(self) -> dict:
        """Get the tool's input arguments schema."""
        return self.args_schema.model_json_schema()["properties"]

    def __repr__(self) -> str:
        return (
            f"CrewStructuredTool(name='{self.name}', description='{self.description}')"
        )
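
`_create_schema_from_function` and `_parse_args` together turn a plain function into a validated call. The signature becomes a Pydantic model built with `create_model`. Incoming arguments, given as a JSON string or a dict, are validated and coerced against that model before the function runs.

The sketch below reproduces that pipeline outside the class, assuming Pydantic v2 (`model_validate`, `model_dump`, `model_fields`). `schema_from_function`, `parse_args`, and `multiply` are hypothetical names for illustration. Two details differ from the listing:

- The sketch skips `*args`/`**kwargs`.
- It lets Pydantic's `ValidationError` propagate instead of wrapping it in `ValueError`.

```python
import inspect
import json
from typing import Any, Callable, Union, get_type_hints

from pydantic import BaseModel, Field, ValidationError, create_model


def schema_from_function(name: str, func: Callable[..., Any]) -> type[BaseModel]:
    """Build a Pydantic model whose fields mirror func's named parameters."""
    sig = inspect.signature(func)
    hints = get_type_hints(func)
    fields: dict[str, Any] = {}
    for pname, param in sig.parameters.items():
        # self/cls and *args/**kwargs have no fixed field, so they are left out.
        if pname in ("self", "cls") or param.kind in (
            inspect.Parameter.VAR_POSITIONAL,
            inspect.Parameter.VAR_KEYWORD,
        ):
            continue
        # Ellipsis marks a field as required in Pydantic.
        default = ... if param.default is inspect.Parameter.empty else param.default
        fields[pname] = (hints.get(pname, Any), Field(default=default))
    return create_model(f"{name.title()}Schema", **fields)


def parse_args(schema: type[BaseModel], raw: Union[str, dict]) -> dict:
    """Decode a JSON string if needed, then validate and coerce against schema."""
    if isinstance(raw, str):
        raw = json.loads(raw)
    return schema.model_validate(raw).model_dump()


def multiply(a: int, b: int = 2) -> int:
    """Multiply two integers."""
    return a * b


MultiplySchema = schema_from_function("multiply", multiply)
print(MultiplySchema.__name__)                   # MultiplySchema
args = parse_args(MultiplySchema, '{"a": "3"}')  # "3" is coerced to int 3
print(args)                                      # {'a': 3, 'b': 2}
print(multiply(**args))                          # 6

try:
    parse_args(MultiplySchema, {"b": 5})         # required "a" is missing
except ValidationError as exc:
    print(exc.errors()[0]["loc"])                # ('a',)
```

Using Ellipsis as the default keeps parameters without a Python default required. Any parameter that has a default becomes an optional field with that same default.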
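`ainvoke` picks one of two execution paths:

- Coroutine functions are awaited directly.
- Plain functions are handed to the event loop's default executor (a thread pool), so a blocking call does not stall other tasks.

The sketch below isolates that dispatch. `call_tool_func`, `blocking_add`, and `async_add` are hypothetical names. The sketch calls `asyncio.get_running_loop()` from inside the coroutine. It binds keyword arguments with `functools.partial` because `loop.run_in_executor` forwards positional arguments only.

```python
import asyncio
import functools
import inspect
from typing import Any, Callable


async def call_tool_func(func: Callable[..., Any], **kwargs: Any) -> Any:
    """Await coroutine functions; run plain functions in the default executor."""
    if inspect.iscoroutinefunction(func):
        return await func(**kwargs)
    loop = asyncio.get_running_loop()
    # run_in_executor passes only positional args, so bind kwargs with partial.
    return await loop.run_in_executor(None, functools.partial(func, **kwargs))


def blocking_add(a: int, b: int) -> int:
    return a + b  # stands in for blocking I/O or CPU work


async def async_add(a: int, b: int) -> int:
    await asyncio.sleep(0)  # yield to the event loop once
    return a + b


async def main() -> list:
    # Both paths can be scheduled concurrently with gather.
    return await asyncio.gather(
        call_tool_func(blocking_add, a=1, b=2),
        call_tool_func(async_add, a=3, b=4),
    )


print(asyncio.run(main()))  # [3, 7]
```

Offloading to a thread keeps the event loop responsive during blocking I/O. In CPython, however, CPU-bound pure-Python work in that thread still contends for the GIL.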