From c403497cf44b7e08712502b289be71eaf9a050b7 Mon Sep 17 00:00:00 2001 From: Lucas Gomide Date: Wed, 14 May 2025 07:36:32 -0300 Subject: [PATCH] feat: support to set an empty context to the Task (#2793) * feat: support to set an empty context to the Task * sytle: fix linter issues --- src/crewai/crew.py | 19 ++++++++++-------- src/crewai/task.py | 6 +++--- src/crewai/telemetry/telemetry.py | 28 +++++++++++++-------------- src/crewai/utilities/constants.py | 11 +++++++++++ src/crewai/utilities/formatter.py | 9 +++++++-- tests/crew_test.py | 32 +++++++++++++++++++++++++------ 6 files changed, 72 insertions(+), 33 deletions(-) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index 102f22881..a0158f646 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -52,7 +52,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools from crewai.tools.base_tool import BaseTool, Tool from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController -from crewai.utilities.constants import TRAINING_DATA_FILE +from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator from crewai.utilities.evaluators.task_evaluator import TaskEvaluator from crewai.utilities.events.crew_events import ( @@ -478,7 +478,7 @@ class Crew(FlowTrackable, BaseModel): separated by a synchronous task. """ for i, task in enumerate(self.tasks): - if task.async_execution and task.context: + if task.async_execution and isinstance(task.context, list): for context_task in task.context: if context_task.async_execution: for j in range(i - 1, -1, -1): @@ -496,7 +496,7 @@ class Crew(FlowTrackable, BaseModel): task_indices = {id(task): i for i, task in enumerate(self.tasks)} for task in self.tasks: - if task.context: + if isinstance(task.context, list): for context_task in task.context: if id(context_task) not in task_indices: continue # Skip context tasks not in the main tasks list @@ -1034,11 +1034,14 @@ class Crew(FlowTrackable, BaseModel): ) return cast(List[BaseTool], tools) - def _get_context(self, task: Task, task_outputs: List[TaskOutput]): + def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str: + if not task.context: + return "" + context = ( - aggregate_raw_outputs_from_tasks(task.context) - if task.context - else aggregate_raw_outputs_from_task_outputs(task_outputs) + aggregate_raw_outputs_from_task_outputs(task_outputs) + if task.context is NOT_SPECIFIED + else aggregate_raw_outputs_from_tasks(task.context) ) return context @@ -1226,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel): task_mapping[task.key] = cloned_task for cloned_task, original_task in zip(cloned_tasks, self.tasks): - if original_task.context: + if isinstance(original_task.context, list): cloned_context = [ task_mapping[context_task.key] for context_task in original_task.context diff --git a/src/crewai/task.py b/src/crewai/task.py index e4a25f438..754fab491 100644 --- a/src/crewai/task.py +++ b/src/crewai/task.py @@ -2,7 +2,6 @@ import datetime import inspect import json import logging -import re import threading import uuid from concurrent.futures import Future @@ -41,6 +40,7 @@ from crewai.tasks.output_format import OutputFormat from crewai.tasks.task_output import TaskOutput from crewai.tools.base_tool import BaseTool from crewai.utilities.config import process_config +from crewai.utilities.constants import NOT_SPECIFIED from crewai.utilities.converter import Converter, convert_to_model from crewai.utilities.events import ( TaskCompletedEvent, @@ -97,7 +97,7 @@ class Task(BaseModel): ) context: Optional[List["Task"]] = Field( description="Other tasks that will have their output used as context for this task.", - default=None, + default=NOT_SPECIFIED, ) async_execution: Optional[bool] = Field( description="Whether the task should be executed asynchronously or not.", @@ -643,7 +643,7 @@ class Task(BaseModel): cloned_context = ( [task_mapping[context_task.key] for context_task in self.context] - if self.context + if isinstance(self.context, list) else None ) diff --git a/src/crewai/telemetry/telemetry.py b/src/crewai/telemetry/telemetry.py index 142cafb2a..e22d757cd 100644 --- a/src/crewai/telemetry/telemetry.py +++ b/src/crewai/telemetry/telemetry.py @@ -10,6 +10,18 @@ from contextlib import contextmanager from importlib.metadata import version from typing import TYPE_CHECKING, Any, Optional +from opentelemetry import trace +from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( + OTLPSpanExporter, +) +from opentelemetry.sdk.resources import SERVICE_NAME, Resource +from opentelemetry.sdk.trace import TracerProvider +from opentelemetry.sdk.trace.export import ( + BatchSpanProcessor, + SpanExportResult, +) +from opentelemetry.trace import Span, Status, StatusCode + from crewai.telemetry.constants import ( CREWAI_TELEMETRY_BASE_URL, CREWAI_TELEMETRY_SERVICE_NAME, @@ -25,18 +37,6 @@ def suppress_warnings(): yield -from opentelemetry import trace # noqa: E402 -from opentelemetry.exporter.otlp.proto.http.trace_exporter import ( - OTLPSpanExporter, # noqa: E402 -) -from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402 -from opentelemetry.sdk.trace import TracerProvider # noqa: E402 -from opentelemetry.sdk.trace.export import ( # noqa: E402 - BatchSpanProcessor, - SpanExportResult, -) -from opentelemetry.trace import Span, Status, StatusCode # noqa: E402 - if TYPE_CHECKING: from crewai.crew import Crew from crewai.task import Task @@ -232,7 +232,7 @@ class Telemetry: "agent_key": task.agent.key if task.agent else None, "context": ( [task.description for task in task.context] - if task.context + if isinstance(task.context, list) else None ), "tools_names": [ @@ -748,7 +748,7 @@ class Telemetry: "agent_key": task.agent.key if task.agent else None, "context": ( [task.description for task in task.context] - if task.context + if isinstance(task.context, list) else None ), "tools_names": [ diff --git a/src/crewai/utilities/constants.py b/src/crewai/utilities/constants.py index 9ff10f1d4..4dbace270 100644 --- a/src/crewai/utilities/constants.py +++ b/src/crewai/utilities/constants.py @@ -5,3 +5,14 @@ KNOWLEDGE_DIRECTORY = "knowledge" MAX_LLM_RETRY = 3 MAX_FILE_NAME_LENGTH = 255 EMITTER_COLOR = "bold_blue" + + +class _NotSpecified: + def __repr__(self): + return "NOT_SPECIFIED" + + +# Sentinel value used to detect when no value has been explicitly provided. +# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows +# us to distinguish between "not passed at all" and "explicitly passed None" or "[]". +NOT_SPECIFIED = _NotSpecified() diff --git a/src/crewai/utilities/formatter.py b/src/crewai/utilities/formatter.py index 19b2a74f9..fd2611472 100644 --- a/src/crewai/utilities/formatter.py +++ b/src/crewai/utilities/formatter.py @@ -1,6 +1,6 @@ -import re from typing import TYPE_CHECKING, List + if TYPE_CHECKING: from crewai.task import Task from crewai.tasks.task_output import TaskOutput @@ -17,6 +17,11 @@ def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> def aggregate_raw_outputs_from_tasks(tasks: List["Task"]) -> str: """Generate string context from the tasks.""" - task_outputs = [task.output for task in tasks if task.output is not None] + + task_outputs = ( + [task.output for task in tasks if task.output is not None] + if isinstance(tasks, list) + else [] + ) return aggregate_raw_outputs_from_task_outputs(task_outputs) diff --git a/tests/crew_test.py b/tests/crew_test.py index a4e4e61df..7c242c825 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -2,22 +2,18 @@ import hashlib import json -import os -import tempfile from concurrent.futures import Future from unittest import mock -from unittest.mock import MagicMock, patch +from unittest.mock import ANY, MagicMock, patch import pydantic_core import pytest from crewai.agent import Agent from crewai.agents import CacheHandler -from crewai.agents.cache import CacheHandler -from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.crew import Crew from crewai.crews.crew_output import CrewOutput -from crewai.flow import Flow, listen, start +from crewai.flow import Flow, start from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource from crewai.llm import LLM from crewai.memory.contextual.contextual_memory import ContextualMemory @@ -3141,6 +3137,30 @@ def test_replay_with_context(): assert crew.tasks[1].context[0].output.raw == "context raw output" +def test_replay_with_context_set_to_nullable(): + agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal") + task1 = Task( + description="Context Task", expected_output="Say Task Output", agent=agent + ) + task2 = Task( + description="Test Task", expected_output="Say Hi", agent=agent, context=[] + ) + task3 = Task( + description="Test Task 3", expected_output="Say Hi", agent=agent, context=None + ) + + crew = Crew(agents=[agent], tasks=[task1, task2, task3], process=Process.sequential) + with patch("crewai.task.Task.execute_sync") as mock_execute_task: + mock_execute_task.return_value = TaskOutput( + description="Test Task Output", + raw="test raw output", + agent="test_agent", + ) + crew.kickoff() + + mock_execute_task.assert_called_with(agent=ANY, context="", tools=ANY) + + @pytest.mark.vcr(filter_headers=["authorization"]) def test_replay_with_invalid_task_id(): agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")