diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 207a1769a..c0d8041c1 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -134,6 +134,9 @@ class BaseAgent(ABC, BaseModel): @model_validator(mode="before") @classmethod def process_model_config(cls, values): + # Handle case where values is a function (can happen with CrewBase decorator) + if callable(values) and not isinstance(values, dict): + return values return process_config(values, cls) @field_validator("tools") diff --git a/src/crewai/project/annotations.py b/src/crewai/project/annotations.py index bf0051c4d..86a77bd63 100644 --- a/src/crewai/project/annotations.py +++ b/src/crewai/project/annotations.py @@ -79,7 +79,14 @@ def crew(func) -> Callable[..., Crew]: # Instantiate tasks in order for task_name, task_method in tasks: + # Get the task instance task_instance = task_method(self) + + # Handle case where agent is a method (function) from CrewBase + if hasattr(task_instance, 'agent') and task_instance.agent and callable(task_instance.agent) and not isinstance(task_instance.agent, type): + # Call the agent method to get the agent instance + task_instance.agent = task_instance.agent() + instantiated_tasks.append(task_instance) agent_instance = getattr(task_instance, "agent", None) if agent_instance and agent_instance.role not in agent_roles: diff --git a/src/crewai/task.py b/src/crewai/task.py index 30ab79c00..f20270ff0 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -61,6 +61,15 @@ class Task(BaseModel): output_pydantic: Pydantic model for task output. tools: List of tools/resources limited for task execution. """ + + def __init__(self, **data): + # Handle case where agent is a callable (can happen with CrewBase decorator) + if 'agent' in data and callable(data['agent']) and not isinstance(data['agent'], type): + # Call the agent method to get the agent instance + data['agent'] = data['agent']() + + # Call the parent class __init__ method + super().__init__(**data) __hash__ = object.__hash__ # type: ignore logger: ClassVar[logging.Logger] = logging.getLogger(__name__) diff --git a/tests/test_crewbase_agent_method.py b/tests/test_crewbase_agent_method.py new file mode 100644 index 000000000..cfb14c523 --- /dev/null +++ b/tests/test_crewbase_agent_method.py @@ -0,0 +1,32 @@ +import unittest +from crewai import Agent, Task + + +class TestTaskInitFix(unittest.TestCase): + """Test the fix for issue #2219 where agent methods are not handled correctly in tasks.""" + + def test_task_init_handles_callable_agent(self): + """Test that the Task.__init__ method correctly handles callable agents.""" + + # Create an agent instance + agent_instance = Agent( + role="Test Agent", + goal="Test Goal", + backstory="Test Backstory" + ) + + # Create a callable that returns the agent instance + def callable_agent(): + return agent_instance + + # Create a task with the callable agent + task = Task( + description="Test Task", + expected_output="Test Output", + agent=callable_agent + ) + + # Verify that the agent in the task is an instance, not a callable + self.assertIsInstance(task.agent, Agent) + self.assertEqual(task.agent.role, "Test Agent") + self.assertIs(task.agent, agent_instance)