Compare commits

...

4 Commits

Author          SHA1          Message                                                  Date
Lucas Gomide    ebcda8b2ec    Merge branch 'main' into lg-support-set-task-context    2025-05-13 18:13:35 -03:00
Lucas Gomide    fced8ba47f    sytle: fix linter issues                                 2025-05-12 11:53:57 -03:00
Lucas Gomide    7204910da4    Merge branch 'main' into lg-support-set-task-context    2025-05-12 09:47:52 -03:00
Lucas Gomide    971a90f534    feat: support to set an empty context to the Task        2025-05-10 09:46:21 -03:00
6 changed files with 72 additions and 33 deletions

View File

@@ -52,7 +52,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.tools.base_tool import BaseTool, Tool
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
-from crewai.utilities.constants import TRAINING_DATA_FILE
+from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
 from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
 from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
 from crewai.utilities.events.crew_events import (
@@ -478,7 +478,7 @@ class Crew(FlowTrackable, BaseModel):
         separated by a synchronous task.
         """
         for i, task in enumerate(self.tasks):
-            if task.async_execution and task.context:
+            if task.async_execution and isinstance(task.context, list):
                 for context_task in task.context:
                     if context_task.async_execution:
                         for j in range(i - 1, -1, -1):
@@ -496,7 +496,7 @@ class Crew(FlowTrackable, BaseModel):
         task_indices = {id(task): i for i, task in enumerate(self.tasks)}
         for task in self.tasks:
-            if task.context:
+            if isinstance(task.context, list):
                 for context_task in task.context:
                     if id(context_task) not in task_indices:
                         continue  # Skip context tasks not in the main tasks list
@@ -1034,11 +1034,14 @@ class Crew(FlowTrackable, BaseModel):
             )
         return cast(List[BaseTool], tools)
 
-    def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
+    def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
+        if not task.context:
+            return ""
+
         context = (
-            aggregate_raw_outputs_from_tasks(task.context)
-            if task.context
-            else aggregate_raw_outputs_from_task_outputs(task_outputs)
+            aggregate_raw_outputs_from_task_outputs(task_outputs)
+            if task.context is NOT_SPECIFIED
+            else aggregate_raw_outputs_from_tasks(task.context)
         )
         return context
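
Note on the hunk above: _get_context now distinguishes three cases. A context that was never provided (the NOT_SPECIFIED sentinel) falls back to the outputs of the previously executed tasks, an explicitly empty context (None or []) short-circuits to an empty string, and an explicit list of tasks is aggregated as before. The following is a minimal standalone sketch of that dispatch, using plain strings and an object() stand-in for the sentinel instead of the real crewai helpers (resolve_context is an illustrative name, not part of the library):

from typing import List

NOT_SPECIFIED = object()  # stand-in for crewai.utilities.constants.NOT_SPECIFIED

def resolve_context(context, prior_outputs: List[str]) -> str:
    # Explicitly empty context (None or []) disables context entirely.
    if not context:
        return ""
    # Never set: fall back to the outputs of previously executed tasks.
    if context is NOT_SPECIFIED:
        return "\n".join(prior_outputs)
    # Explicit list: aggregate only the listed entries.
    return "\n".join(context)

assert resolve_context(NOT_SPECIFIED, ["a", "b"]) == "a\nb"
assert resolve_context([], ["a"]) == ""
assert resolve_context(None, ["a"]) == ""
assert resolve_context(["only this"], ["a"]) == "only this"
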
@@ -1226,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel):
             task_mapping[task.key] = cloned_task
 
         for cloned_task, original_task in zip(cloned_tasks, self.tasks):
-            if original_task.context:
+            if isinstance(original_task.context, list):
                 cloned_context = [
                     task_mapping[context_task.key]
                     for context_task in original_task.context

View File

@@ -2,7 +2,6 @@ import datetime
 import inspect
 import json
 import logging
-import re
 import threading
 import uuid
 from concurrent.futures import Future
@@ -41,6 +40,7 @@ from crewai.tasks.output_format import OutputFormat
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.base_tool import BaseTool
 from crewai.utilities.config import process_config
+from crewai.utilities.constants import NOT_SPECIFIED
 from crewai.utilities.converter import Converter, convert_to_model
 from crewai.utilities.events import (
     TaskCompletedEvent,
@@ -97,7 +97,7 @@ class Task(BaseModel):
     )
     context: Optional[List["Task"]] = Field(
         description="Other tasks that will have their output used as context for this task.",
-        default=None,
+        default=NOT_SPECIFIED,
     )
     async_execution: Optional[bool] = Field(
         description="Whether the task should be executed asynchronously or not.",
@@ -643,7 +643,7 @@ class Task(BaseModel):
         cloned_context = (
             [task_mapping[context_task.key] for context_task in self.context]
-            if self.context
+            if isinstance(self.context, list)
             else None
         )
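
Note on the hunk above: because Task.context now defaults to NOT_SPECIFIED rather than None, the field's annotation (Optional[List["Task"]]) no longer matches the default's type. Pydantic tolerates this because field defaults are not validated unless validate_default=True. A toy sketch of the same pattern, with MiniTask as a hypothetical model rather than the real crewai Task:

from typing import List, Optional

from pydantic import BaseModel, Field

NOT_SPECIFIED = object()  # stand-in for the sentinel added in constants.py below

class MiniTask(BaseModel):
    # The default is the sentinel; Pydantic leaves defaults unvalidated by
    # default, so the non-list object passes through untouched.
    context: Optional[List[str]] = Field(default=NOT_SPECIFIED)

print(MiniTask().context is NOT_SPECIFIED)              # True  -> never provided
print(MiniTask(context=None).context)                   # None  -> explicitly cleared
print(isinstance(MiniTask(context=[]).context, list))   # True  -> explicitly empty
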

View File

@@ -10,6 +10,18 @@ from contextlib import contextmanager
 from importlib.metadata import version
 from typing import TYPE_CHECKING, Any, Optional
 
+from opentelemetry import trace
+from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
+    OTLPSpanExporter,
+)
+from opentelemetry.sdk.resources import SERVICE_NAME, Resource
+from opentelemetry.sdk.trace import TracerProvider
+from opentelemetry.sdk.trace.export import (
+    BatchSpanProcessor,
+    SpanExportResult,
+)
+from opentelemetry.trace import Span, Status, StatusCode
+
 from crewai.telemetry.constants import (
     CREWAI_TELEMETRY_BASE_URL,
     CREWAI_TELEMETRY_SERVICE_NAME,
@@ -25,18 +37,6 @@ def suppress_warnings():
         yield
 
-from opentelemetry import trace  # noqa: E402
-from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
-    OTLPSpanExporter,  # noqa: E402
-)
-from opentelemetry.sdk.resources import SERVICE_NAME, Resource  # noqa: E402
-from opentelemetry.sdk.trace import TracerProvider  # noqa: E402
-from opentelemetry.sdk.trace.export import (  # noqa: E402
-    BatchSpanProcessor,
-    SpanExportResult,
-)
-from opentelemetry.trace import Span, Status, StatusCode  # noqa: E402
-
 if TYPE_CHECKING:
     from crewai.crew import Crew
     from crewai.task import Task
@@ -232,7 +232,7 @@ class Telemetry:
"agent_key": task.agent.key if task.agent else None, "agent_key": task.agent.key if task.agent else None,
"context": ( "context": (
[task.description for task in task.context] [task.description for task in task.context]
if task.context if isinstance(task.context, list)
else None else None
), ),
"tools_names": [ "tools_names": [
@@ -748,7 +748,7 @@ class Telemetry:
"agent_key": task.agent.key if task.agent else None, "agent_key": task.agent.key if task.agent else None,
"context": ( "context": (
[task.description for task in task.context] [task.description for task in task.context]
if task.context if isinstance(task.context, list)
else None else None
), ),
"tools_names": [ "tools_names": [

View File

@@ -5,3 +5,14 @@ KNOWLEDGE_DIRECTORY = "knowledge"
 MAX_LLM_RETRY = 3
 MAX_FILE_NAME_LENGTH = 255
 EMITTER_COLOR = "bold_blue"
+
+
+class _NotSpecified:
+    def __repr__(self):
+        return "NOT_SPECIFIED"
+
+
+# Sentinel value used to detect when no value has been explicitly provided.
+# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows
+# us to distinguish between "not passed at all" and "explicitly passed None" or "[]".
+NOT_SPECIFIED = _NotSpecified()
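
Note on the hunk above: the sentinel is a plain class instance rather than None so that callers can still pass None (or []) as a meaningful value; consumers check for it by identity (is NOT_SPECIFIED), and the custom __repr__ just keeps debugging output readable. A small usage sketch of the idiom, where describe is an illustrative helper and not part of crewai:

class _NotSpecified:
    def __repr__(self):
        return "NOT_SPECIFIED"

NOT_SPECIFIED = _NotSpecified()

def describe(context=NOT_SPECIFIED) -> str:
    # Identity check: never compare the sentinel with == or rely on truthiness.
    if context is NOT_SPECIFIED:
        return "argument omitted"
    if context is None:
        return "explicitly passed None"
    return f"explicitly passed {context!r}"

print(describe())       # argument omitted
print(describe(None))   # explicitly passed None
print(describe([]))     # explicitly passed []
print(NOT_SPECIFIED)    # NOT_SPECIFIED
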

View File

@@ -1,6 +1,6 @@
-import re
 from typing import TYPE_CHECKING, List
 
 if TYPE_CHECKING:
     from crewai.task import Task
     from crewai.tasks.task_output import TaskOutput
@@ -17,6 +17,11 @@ def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) -> str:
 def aggregate_raw_outputs_from_tasks(tasks: List["Task"]) -> str:
     """Generate string context from the tasks."""
-    task_outputs = [task.output for task in tasks if task.output is not None]
+    task_outputs = (
+        [task.output for task in tasks if task.output is not None]
+        if isinstance(tasks, list)
+        else []
+    )
 
     return aggregate_raw_outputs_from_task_outputs(task_outputs)

View File

@@ -2,22 +2,18 @@
 import hashlib
 import json
-import os
-import tempfile
 from concurrent.futures import Future
 from unittest import mock
-from unittest.mock import MagicMock, patch
+from unittest.mock import ANY, MagicMock, patch
 
 import pydantic_core
 import pytest
 
 from crewai.agent import Agent
 from crewai.agents import CacheHandler
-from crewai.agents.cache import CacheHandler
-from crewai.agents.crew_agent_executor import CrewAgentExecutor
 from crewai.crew import Crew
 from crewai.crews.crew_output import CrewOutput
-from crewai.flow import Flow, listen, start
+from crewai.flow import Flow, start
 from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
 from crewai.llm import LLM
 from crewai.memory.contextual.contextual_memory import ContextualMemory
@@ -3141,6 +3137,30 @@ def test_replay_with_context():
 assert crew.tasks[1].context[0].output.raw == "context raw output"
 
 
+def test_replay_with_context_set_to_nullable():
+    agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")
+    task1 = Task(
+        description="Context Task", expected_output="Say Task Output", agent=agent
+    )
+    task2 = Task(
+        description="Test Task", expected_output="Say Hi", agent=agent, context=[]
+    )
+    task3 = Task(
+        description="Test Task 3", expected_output="Say Hi", agent=agent, context=None
+    )
+
+    crew = Crew(agents=[agent], tasks=[task1, task2, task3], process=Process.sequential)
+
+    with patch("crewai.task.Task.execute_sync") as mock_execute_task:
+        mock_execute_task.return_value = TaskOutput(
+            description="Test Task Output",
+            raw="test raw output",
+            agent="test_agent",
+        )
+
+        crew.kickoff()
+
+        mock_execute_task.assert_called_with(agent=ANY, context="", tools=ANY)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_replay_with_invalid_task_id():
     agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")