Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
b579cb2cce Address code review comments and fix lint error
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-25 13:36:50 +00:00
Devin AI
a46fa50f73 Fix issue #2219: Handle callable agents in CrewBase decorator
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-25 13:24:35 +00:00
8 changed files with 144 additions and 74 deletions

View File

@@ -53,8 +53,6 @@ openpyxl = [
mem0 = ["mem0ai>=0.1.29"]
docling = [
"docling>=2.12.0",
# Required for transformers compatibility
"tokenizers>=0.21,<0.22",
]
[tool.uv]

View File

@@ -134,6 +134,19 @@ class BaseAgent(ABC, BaseModel):
@model_validator(mode="before")
@classmethod
def process_model_config(cls, values):
"""
Process model configuration values.
Args:
values: Configuration values or callable agent
When using CrewBase decorator, this can be a callable that returns an agent
Returns:
Processed configuration or callable agent
"""
# Handle case where values is a function (can happen with CrewBase decorator)
if callable(values) and not isinstance(values, dict):
return values
return process_config(values, cls)
@field_validator("tools")

View File

@@ -15,15 +15,8 @@ from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to Markdown or JSON
This will auto support PDF, DOCX, TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
Requirements:
- Install with: `pip install crewai[docling]`
- Requires tokenizers>=0.21,<0.22 for transformers compatibility
Notes:
- This is an optional dependency, only needed for document processing features.
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""
_logger: Logger = Logger(verbose=True)

View File

@@ -65,6 +65,27 @@ def cache_handler(func):
return memoize(func)
def _resolve_agent(task_instance):
"""
Resolve an agent from a task instance.
If the agent is a callable (e.g., a method from CrewBase), call it to get the agent instance.
Args:
task_instance: The task instance containing the agent
Returns:
The resolved agent instance or None if no agent is present
"""
if not hasattr(task_instance, 'agent') or not task_instance.agent:
return None
if callable(task_instance.agent) and not isinstance(task_instance.agent, type):
return task_instance.agent()
return task_instance.agent
def crew(func) -> Callable[..., Crew]:
@wraps(func)
@@ -79,7 +100,14 @@ def crew(func) -> Callable[..., Crew]:
# Instantiate tasks in order
for task_name, task_method in tasks:
# Get the task instance
task_instance = task_method(self)
# Resolve the agent
agent = _resolve_agent(task_instance)
if agent:
task_instance.agent = agent
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:

View File

@@ -61,6 +61,25 @@ class Task(BaseModel):
output_pydantic: Pydantic model for task output.
tools: List of tools/resources limited for task execution.
"""
def __init__(self, **data):
# Handle case where agent is a callable (can happen with CrewBase decorator)
if 'agent' in data and callable(data['agent']) and not isinstance(data['agent'], type):
try:
# Call the agent method to get the agent instance
agent = data['agent']()
# Verify that the agent is a valid instance
from crewai.agents.agent_builder.base_agent import BaseAgent
if agent is not None and not isinstance(agent, BaseAgent):
raise ValueError(f"Expected BaseAgent instance, got {type(agent)}")
data['agent'] = agent
except Exception as e:
raise ValueError(f"Failed to initialize agent from callable: {e}")
# Call the parent class __init__ method
super().__init__(**data)
__hash__ = object.__hash__ # type: ignore
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)

View File

@@ -0,0 +1,51 @@
import unittest
from crewai import Agent, Task
class TestTaskInitFix(unittest.TestCase):
"""Test the fix for issue #2219 where agent methods are not handled correctly in tasks."""
def test_task_init_handles_callable_agent(self):
"""Test that the Task.__init__ method correctly handles callable agents."""
# Create an agent instance
agent_instance = Agent(
role="Test Agent",
goal="Test Goal",
backstory="Test Backstory"
)
# Create a callable that returns the agent instance
def callable_agent():
return agent_instance
# Create a task with the callable agent
task = Task(
description="Test Task",
expected_output="Test Output",
agent=callable_agent
)
# Verify that the agent in the task is an instance, not a callable
self.assertIsInstance(task.agent, Agent)
self.assertEqual(task.agent.role, "Test Agent")
self.assertIs(task.agent, agent_instance)
def test_task_init_handles_invalid_callable_agent(self):
"""Test that the Task.__init__ method correctly handles invalid callable agents."""
# Create a callable that returns an invalid agent (not an Agent instance)
def invalid_callable_agent():
return "Not an agent"
# Create a task with the invalid callable agent
with self.assertRaises(ValueError) as context:
task = Task(
description="Test Task",
expected_output="Test Output",
agent=invalid_callable_agent
)
# Verify that the error message is correct
self.assertIn("Expected BaseAgent instance", str(context.exception))

View File

@@ -1,26 +0,0 @@
"""Test suite to verify compatibility between tokenizers and transformers packages.
Ensures the installed tokenizers version meets requirements and can work effectively with transformers.
"""
import pytest
def test_tokenizers_transformers_compatibility():
"""Test that the installed tokenizers version is compatible with transformers."""
try:
import tokenizers
import transformers
except ImportError:
pytest.skip("tokenizers or transformers not installed")
tokenizers_version = tokenizers.__version__
transformers_version = transformers.__version__
tokenizers_major, tokenizers_minor, _ = map(int, tokenizers_version.split('.'))
assert tokenizers_major == 0, f"Expected tokenizers major version 0, got {tokenizers_major}"
assert tokenizers_minor >= 21, f"Expected tokenizers minor version >=21, got {tokenizers_minor}"
assert tokenizers_minor < 22, f"Expected tokenizers minor version <22, got {tokenizers_minor}"
print(f"Tokenizers version: {tokenizers_version}")
print(f"Transformers version: {transformers_version}")

68
uv.lock generated
View File

@@ -1,18 +1,10 @@
version = 1
requires-python = ">=3.10, <3.13"
resolution-markers = [
"python_full_version < '3.11' and sys_platform == 'darwin'",
"python_full_version < '3.11' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version < '3.11' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version < '3.11' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version == '3.11.*' and sys_platform == 'darwin'",
"python_full_version == '3.11.*' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version == '3.11.*' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version == '3.11.*' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12' and python_full_version < '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12' and python_full_version < '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version >= '3.12.4' and sys_platform == 'darwin'",
"python_full_version >= '3.12.4' and platform_machine == 'aarch64' and sys_platform == 'linux'",
"(python_full_version >= '3.12.4' and platform_machine != 'aarch64' and sys_platform == 'linux') or (python_full_version >= '3.12.4' and sys_platform != 'darwin' and sys_platform != 'linux')",
"python_full_version < '3.11'",
"python_full_version == '3.11.*'",
"python_full_version >= '3.12' and python_full_version < '3.12.4'",
"python_full_version >= '3.12.4'",
]
[[package]]
@@ -308,7 +300,7 @@ name = "build"
version = "1.2.2.post1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "(os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "colorama", marker = "os_name == 'nt'" },
{ name = "importlib-metadata", marker = "python_full_version < '3.10.2'" },
{ name = "packaging" },
{ name = "pyproject-hooks" },
@@ -543,7 +535,7 @@ name = "click"
version = "8.1.7"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/96/d3/f04c7bfcf5c1862a2a5b845c6b2b360488cf47af55dfa79c98f6a6bf98b5/click-8.1.7.tar.gz", hash = "sha256:ca9853ad459e787e2192211578cc907e7594e294c7ccc834310722b41b9ca6de", size = 336121 }
wheels = [
@@ -650,6 +642,7 @@ tools = [
[package.dev-dependencies]
dev = [
{ name = "cairosvg" },
{ name = "crewai-tools" },
{ name = "mkdocs" },
{ name = "mkdocs-material" },
{ name = "mkdocs-material-extensions" },
@@ -703,6 +696,7 @@ requires-dist = [
[package.metadata.requires-dev]
dev = [
{ name = "cairosvg", specifier = ">=2.7.1" },
{ name = "crewai-tools", specifier = ">=0.17.0" },
{ name = "mkdocs", specifier = ">=1.4.3" },
{ name = "mkdocs-material", specifier = ">=9.5.7" },
{ name = "mkdocs-material-extensions", specifier = ">=1.3.1" },
@@ -2468,7 +2462,7 @@ version = "1.6.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "click" },
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
{ name = "ghp-import" },
{ name = "jinja2" },
{ name = "markdown" },
@@ -2649,7 +2643,7 @@ version = "2.10.2"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pygments" },
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
{ name = "tqdm" },
]
sdist = { url = "https://files.pythonhosted.org/packages/3a/93/80ac75c20ce54c785648b4ed363c88f148bf22637e10c9863db4fbe73e74/mpire-2.10.2.tar.gz", hash = "sha256:f66a321e93fadff34585a4bfa05e95bd946cf714b442f51c529038eb45773d97", size = 271270 }
@@ -2896,7 +2890,7 @@ name = "nvidia-cudnn-cu12"
version = "9.1.0.70"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/9f/fd/713452cd72343f682b1c7b9321e23829f00b842ceaedcda96e742ea0b0b3/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl", hash = "sha256:165764f44ef8c61fcdfdfdbe769d687e06374059fbb388b6c89ecb0e28793a6f", size = 664752741 },
@@ -2923,9 +2917,9 @@ name = "nvidia-cusolver-cu12"
version = "11.4.5.107"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-cublas-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-cusparse-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/bc/1d/8de1e5c67099015c834315e333911273a8c6aaba78923dd1d1e25fc5f217/nvidia_cusolver_cu12-11.4.5.107-py3-none-manylinux1_x86_64.whl", hash = "sha256:8a7ec542f0412294b15072fa7dab71d31334014a69f953004ea7a118206fe0dd", size = 124161928 },
@@ -2936,7 +2930,7 @@ name = "nvidia-cusparse-cu12"
version = "12.1.0.106"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "nvidia-nvjitlink-cu12", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/65/5b/cfaeebf25cd9fdec14338ccb16f6b2c4c7fa9163aefcf057d86b9cc248bb/nvidia_cusparse_cu12-12.1.0.106-py3-none-manylinux1_x86_64.whl", hash = "sha256:f3b50f42cf363f86ab21f720998517a659a48131e8d538dc02f8768237bd884c", size = 195958278 },
@@ -3486,7 +3480,7 @@ name = "portalocker"
version = "2.10.1"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "pywin32", marker = "sys_platform == 'win32'" },
{ name = "pywin32", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/ed/d3/c6c64067759e87af98cc668c1cc75171347d0f1577fab7ca3749134e3cd4/portalocker-2.10.1.tar.gz", hash = "sha256:ef1bf844e878ab08aee7e40184156e1151f228f103aa5c6bd0724cc330960f8f", size = 40891 }
wheels = [
@@ -5028,19 +5022,19 @@ dependencies = [
{ name = "fsspec" },
{ name = "jinja2" },
{ name = "networkx" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "nvidia-cublas-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-cupti-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-nvrtc-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cuda-runtime-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cudnn-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cufft-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-curand-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusolver-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-cusparse-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nccl-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "nvidia-nvtx-cu12", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "sympy" },
{ name = "triton", marker = "platform_machine == 'x86_64' and sys_platform == 'linux'" },
{ name = "triton", marker = "platform_machine == 'x86_64' and platform_system == 'Linux'" },
{ name = "typing-extensions" },
]
wheels = [
@@ -5087,7 +5081,7 @@ name = "tqdm"
version = "4.66.5"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "colorama", marker = "sys_platform == 'win32'" },
{ name = "colorama", marker = "platform_system == 'Windows'" },
]
sdist = { url = "https://files.pythonhosted.org/packages/58/83/6ba9844a41128c62e810fddddd72473201f3eacde02046066142a2d96cc5/tqdm-4.66.5.tar.gz", hash = "sha256:e1020aef2e5096702d8a025ac7d16b1577279c9d63f8375b63083e9a5f0fcbad", size = 169504 }
wheels = [
@@ -5130,7 +5124,7 @@ version = "0.27.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "attrs" },
{ name = "cffi", marker = "(implementation_name != 'pypy' and os_name == 'nt' and platform_machine != 'aarch64' and sys_platform == 'linux') or (implementation_name != 'pypy' and os_name == 'nt' and sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "cffi", marker = "implementation_name != 'pypy' and os_name == 'nt'" },
{ name = "exceptiongroup", marker = "python_full_version < '3.11'" },
{ name = "idna" },
{ name = "outcome" },
@@ -5161,7 +5155,7 @@ name = "triton"
version = "3.0.0"
source = { registry = "https://pypi.org/simple" }
dependencies = [
{ name = "filelock", marker = "(platform_machine != 'aarch64' and sys_platform == 'linux') or (sys_platform != 'darwin' and sys_platform != 'linux')" },
{ name = "filelock", marker = "(platform_machine != 'aarch64' and platform_system != 'Darwin') or (platform_system != 'Darwin' and platform_system != 'Linux')" },
]
wheels = [
{ url = "https://files.pythonhosted.org/packages/45/27/14cc3101409b9b4b9241d2ba7deaa93535a217a211c86c4cc7151fb12181/triton-3.0.0-1-cp310-cp310-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:e1efef76935b2febc365bfadf74bcb65a6f959a9872e5bddf44cc9e0adce1e1a", size = 209376304 },