diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 102f22881..a0158f646 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -52,7 +52,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools from crewai.tools.base_tool import BaseTool, Tool from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController -from crewai.utilities.constants import TRAINING_DATA_FILE +from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.events.crew_events import ( @@ -478,7 +478,7 @@ class Crew(FlowTrackable, BaseModel): separated by a synchronous task. """ for i, task in enumerate(self.tasks): - if task.async_execution and task.context: + if task.async_execution and isinstance(task.context, list): for context_task in task.context: if context_task.async_execution: for j in range(i - 1, -1, -1): @@ -496,7 +496,7 @@ class Crew(FlowTrackable, BaseModel): task_indices = {id(task): i for i, task in enumerate(self.tasks)} for task in self.tasks: - if task.context: + if isinstance(task.context, list): for context_task in task.context: if id(context_task) not in task_indices: continue # Skip context tasks not in the main tasks list @@ -1034,11 +1034,14 @@ class Crew(FlowTrackable, BaseModel): ) return cast(List[BaseTool], tools) - def _get_context(self, task: Task, task_outputs: List[TaskOutput]): + def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str: + if not task.context: + return "" + context = ( - aggregate_raw_outputs_from_tasks(task.context) - if task.context - else aggregate_raw_outputs_from_task_outputs(task_outputs) + aggregate_raw_outputs_from_task_outputs(task_outputs) + if task.context is NOT_SPECIFIED + else aggregate_raw_outputs_from_tasks(task.context) ) return context @@ -1226,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel): task_mapping[task.key] = cloned_task for cloned_task, original_task in zip(cloned_tasks, self.tasks): - if original_task.context: + if isinstance(original_task.context, list): cloned_context = [ task_mapping[context_task.key] for context_task in original_task.context diff --git a/src/crewai/task.py b/src/crewai/task.py index e4a25f438..5204ba92e 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -41,6 +41,7 @@ from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput from crewai.tools.base_tool import BaseTool from crewai.utilities.config import process_config +from crewai.utilities.constants import NOT_SPECIFIED from crewai.utilities.converter import Converter, convert_to_model from crewai.utilities.events import ( TaskCompletedEvent, @@ -97,7 +98,7 @@ class Task(BaseModel): ) context: Optional[List["Task"]] = Field( description="Other tasks that will have their output used as context for this task.", - default=None, + default=NOT_SPECIFIED, ) async_execution: Optional[bool] = Field( description="Whether the task should be executed asynchronously or not.", @@ -643,7 +644,7 @@ class Task(BaseModel): cloned_context = ( [task_mapping[context_task.key] for context_task in self.context] - if self.context + if isinstance(self.context, list) else None ) diff --git a/src/crewai/telemetry/telemetry.py b/src/crewai/telemetry/telemetry.py index 142cafb2a..ffd88a330 100644 --- a/src/crewai/telemetry/telemetry.py +++ b/src/crewai/telemetry/telemetry.py @@ -232,7 +232,7 @@ class Telemetry: "agent_key": task.agent.key if task.agent else None, "context": ( [task.description for task in task.context] - if task.context + if isinstance(task.context, list) else None ), "tools_names": [ @@ -748,7 +748,7 @@ class Telemetry: "agent_key": task.agent.key if task.agent else None, "context": ( [task.description for task in task.context] - if task.context + if isinstance(task.context, list) else None ), "tools_names": [ diff --git a/src/crewai/utilities/constants.py b/src/crewai/utilities/constants.py index 9ff10f1d4..4dbace270 100644 --- a/src/crewai/utilities/constants.py +++ b/src/crewai/utilities/constants.py @@ -5,3 +5,14 @@ KNOWLEDGE_DIRECTORY = "knowledge" MAX_LLM_RETRY = 3 MAX_FILE_NAME_LENGTH = 255 EMITTER_COLOR = "bold_blue" + + +class _NotSpecified: + def __repr__(self): + return "NOT_SPECIFIED" + + +# Sentinel value used to detect when no value has been explicitly provided. +# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows +# us to distinguish between "not passed at all" and "explicitly passed None" or "[]". +NOT_SPECIFIED = _NotSpecified() diff --git a/src/crewai/utilities/formatter.py b/src/crewai/utilities/formatter.py index 19b2a74f9..9c2da70c6 100644 --- a/src/crewai/utilities/formatter.py +++ b/src/crewai/utilities/formatter.py @@ -1,5 +1,7 @@ import re -from typing import TYPE_CHECKING, List +from typing import TYPE_CHECKING, List, Type + +from crewai.utilities.constants import NOT_SPECIFIED if TYPE_CHECKING: from crewai.task import Task @@ -17,6 +19,11 @@ def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> def aggregate_raw_outputs_from_tasks(tasks: List["Task"]) -> str: """Generate string context from the tasks.""" - task_outputs = [task.output for task in tasks if task.output is not None] + + task_outputs = ( + [task.output for task in tasks if task.output is not None] + if isinstance(tasks, list) + else [] + ) return aggregate_raw_outputs_from_task_outputs(task_outputs) diff --git a/tests/crew_test.py b/tests/crew_test.py index a4e4e61df..4ca1fe878 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -6,7 +6,7 @@ import os import tempfile from concurrent.futures import Future from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pydantic_core import pytest @@ -3141,6 +3141,30 @@ def test_replay_with_context(): assert crew.tasks[1].context[0].output.raw == "context raw output" +def test_replay_with_context_set_to_nullable(): + agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal") + task1 = Task( + description="Context Task", expected_output="Say Task Output", agent=agent + ) + task2 = Task( + description="Test Task", expected_output="Say Hi", agent=agent, context=[] + ) + task3 = Task( + description="Test Task 3", expected_output="Say Hi", agent=agent, context=None + ) + + crew = Crew(agents=[agent], tasks=[task1, task2, task3], process=Process.sequential) + with patch("crewai.task.Task.execute_sync") as mock_execute_task: + mock_execute_task.return_value = TaskOutput( + description="Test Task Output", + raw="test raw output", + agent="test_agent", + ) + crew.kickoff() + + mock_execute_task.assert_called_with(agent=ANY, context="", tools=ANY) + + @pytest.mark.vcr(filter_headers=["authorization"]) def test_replay_with_invalid_task_id(): agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")