diff --git a/reproduce_issue_3226.py b/reproduce_issue_3226.py index 6a0709dbb..93fef78f4 100644 --- a/reproduce_issue_3226.py +++ b/reproduce_issue_3226.py @@ -6,7 +6,6 @@ This script tests all the failing patterns mentioned in the issue. import sys import traceback -from typing import Any def test_function_tool(): """Test 1: Function Tool with @tool decorator""" @@ -20,13 +19,14 @@ def test_function_tool(): """Fetch logs from New Relic based on query""" return f"Logs for query: {query}" - teacher = Agent( + agent = Agent( role='CrashFetcher', goal='Extract logs', backstory='An agent that fetches logs', tools=[fetch_logs], allow_delegation=False ) + assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}" print("✅ Function tool with @tool decorator: SUCCESS") return True except Exception as e: @@ -49,13 +49,14 @@ def test_dict_tool(): 'func': fetch_logs_func } - teacher = Agent( + agent = Agent( role='CrashFetcher', goal='Extract logs', backstory='An agent that fetches logs', tools=[fetch_logs_dict], allow_delegation=False ) + assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}" print("✅ Dict-based tool: SUCCESS") return True except Exception as e: @@ -77,13 +78,14 @@ def test_basetool_class(): def _run(self, query: str) -> str: return f"Logs for query: {query}" - teacher = Agent( + agent = Agent( role='CrashFetcher', goal='Extract logs', backstory='An agent that fetches logs', tools=[FetchLogsTool()], allow_delegation=False ) + assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}" print("✅ BaseTool class inheritance: SUCCESS") return True except Exception as e: @@ -101,13 +103,14 @@ def test_direct_function(): """Fetch logs from New Relic based on query""" return f"Logs for query: {query}" - teacher = Agent( + agent = Agent( role='CrashFetcher', goal='Extract logs', backstory='An agent that fetches logs', tools=[fetch_logs], allow_delegation=False ) + assert len(agent.tools) == 1, f"Expected 1 tool, got {len(agent.tools)}" print("✅ Direct function assignment: SUCCESS") return True except Exception as e: @@ -125,7 +128,7 @@ def main(): results.append(test_basetool_class()) results.append(test_direct_function()) - print(f"\n=== SUMMARY ===") + print("\n=== SUMMARY ===") passed = sum(results) total = len(results) print(f"Tests passed: {passed}/{total}") diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index c3778d0bf..7806c7f15 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -2,7 +2,7 @@ import uuid from abc import ABC, abstractmethod from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, TypeVar +from typing import Any, Callable, Dict, List, Optional, TypeVar, cast from pydantic import ( UUID4, @@ -25,7 +25,6 @@ from crewai.security.security_config import SecurityConfig from crewai.tools.base_tool import BaseTool, Tool from crewai.utilities import I18N, Logger, RPMController from crewai.utilities.config import process_config -from crewai.utilities.converter import Converter from crewai.utilities.string_utils import interpolate_only T = TypeVar("T", bound="BaseAgent") @@ -186,7 +185,7 @@ class BaseAgent(ABC, BaseModel): processed_tools.append(tool_item) elif callable(tool_item): if hasattr(tool_item, '__doc__') and tool_item.__doc__: - converted_tool = tool(tool_item) + converted_tool = cast(BaseTool, tool(tool_item)) processed_tools.append(converted_tool) else: raise ValueError( diff --git a/src/crewai/utilities/agent_utils.py b/src/crewai/utilities/agent_utils.py index 700ecd2e1..a6f7f2af8 100644 --- a/src/crewai/utilities/agent_utils.py +++ b/src/crewai/utilities/agent_utils.py @@ -11,7 +11,6 @@ from crewai.agents.parser import ( ) from crewai.llm import LLM from crewai.llms.base_llm import BaseLLM -from crewai.tools import BaseTool as CrewAITool from crewai.tools.base_tool import BaseTool from crewai.tools.structured_tool import CrewStructuredTool from crewai.tools.tool_types import ToolResult diff --git a/tests/test_custom_tools.py b/tests/test_custom_tools.py index baebcc46e..050061ca4 100644 --- a/tests/test_custom_tools.py +++ b/tests/test_custom_tools.py @@ -4,9 +4,8 @@ This addresses issue #3226 where custom tool registration was broken in CrewAI 0 """ import pytest -from typing import Any from crewai import Agent -from crewai.tools import BaseTool, tool, Tool +from crewai.tools import BaseTool, tool class TestCustomToolPatterns: