Compare commits

...

4 Commits

Author SHA1 Message Date
Lucas Gomide
ebcda8b2ec Merge branch 'main' into lg-support-set-task-context 2025-05-13 18:13:35 -03:00
Lucas Gomide
fced8ba47f sytle: fix linter issues 2025-05-12 11:53:57 -03:00
Lucas Gomide
7204910da4 Merge branch 'main' into lg-support-set-task-context 2025-05-12 09:47:52 -03:00
Lucas Gomide
971a90f534 feat: support to set an empty context to the Task 2025-05-10 09:46:21 -03:00
6 changed files with 72 additions and 33 deletions

View File

@@ -52,7 +52,7 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool, Tool
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities import I18N, FileHandler, Logger, RPMController
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.evaluators.crew_evaluator_handler import CrewEvaluator
from crewai.utilities.evaluators.task_evaluator import TaskEvaluator
from crewai.utilities.events.crew_events import (
@@ -478,7 +478,7 @@ class Crew(FlowTrackable, BaseModel):
separated by a synchronous task.
"""
for i, task in enumerate(self.tasks):
if task.async_execution and task.context:
if task.async_execution and isinstance(task.context, list):
for context_task in task.context:
if context_task.async_execution:
for j in range(i - 1, -1, -1):
@@ -496,7 +496,7 @@ class Crew(FlowTrackable, BaseModel):
task_indices = {id(task): i for i, task in enumerate(self.tasks)}
for task in self.tasks:
if task.context:
if isinstance(task.context, list):
for context_task in task.context:
if id(context_task) not in task_indices:
continue # Skip context tasks not in the main tasks list
@@ -1034,11 +1034,14 @@ class Crew(FlowTrackable, BaseModel):
)
return cast(List[BaseTool], tools)
def _get_context(self, task: Task, task_outputs: List[TaskOutput]):
def _get_context(self, task: Task, task_outputs: List[TaskOutput]) -> str:
if not task.context:
return ""
context = (
aggregate_raw_outputs_from_tasks(task.context)
if task.context
else aggregate_raw_outputs_from_task_outputs(task_outputs)
aggregate_raw_outputs_from_task_outputs(task_outputs)
if task.context is NOT_SPECIFIED
else aggregate_raw_outputs_from_tasks(task.context)
)
return context
@@ -1226,7 +1229,7 @@ class Crew(FlowTrackable, BaseModel):
task_mapping[task.key] = cloned_task
for cloned_task, original_task in zip(cloned_tasks, self.tasks):
if original_task.context:
if isinstance(original_task.context, list):
cloned_context = [
task_mapping[context_task.key]
for context_task in original_task.context

View File

@@ -2,7 +2,6 @@ import datetime
import inspect
import json
import logging
import re
import threading
import uuid
from concurrent.futures import Future
@@ -41,6 +40,7 @@ from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.tools.base_tool import BaseTool
from crewai.utilities.config import process_config
from crewai.utilities.constants import NOT_SPECIFIED
from crewai.utilities.converter import Converter, convert_to_model
from crewai.utilities.events import (
TaskCompletedEvent,
@@ -97,7 +97,7 @@ class Task(BaseModel):
)
context: Optional[List["Task"]] = Field(
description="Other tasks that will have their output used as context for this task.",
default=None,
default=NOT_SPECIFIED,
)
async_execution: Optional[bool] = Field(
description="Whether the task should be executed asynchronously or not.",
@@ -643,7 +643,7 @@ class Task(BaseModel):
cloned_context = (
[task_mapping[context_task.key] for context_task in self.context]
if self.context
if isinstance(self.context, list)
else None
)

View File

@@ -10,6 +10,18 @@ from contextlib import contextmanager
from importlib.metadata import version
from typing import TYPE_CHECKING, Any, Optional
from opentelemetry import trace
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter,
)
from opentelemetry.sdk.resources import SERVICE_NAME, Resource
from opentelemetry.sdk.trace import TracerProvider
from opentelemetry.sdk.trace.export import (
BatchSpanProcessor,
SpanExportResult,
)
from opentelemetry.trace import Span, Status, StatusCode
from crewai.telemetry.constants import (
CREWAI_TELEMETRY_BASE_URL,
CREWAI_TELEMETRY_SERVICE_NAME,
@@ -25,18 +37,6 @@ def suppress_warnings():
yield
from opentelemetry import trace # noqa: E402
from opentelemetry.exporter.otlp.proto.http.trace_exporter import (
OTLPSpanExporter, # noqa: E402
)
from opentelemetry.sdk.resources import SERVICE_NAME, Resource # noqa: E402
from opentelemetry.sdk.trace import TracerProvider # noqa: E402
from opentelemetry.sdk.trace.export import ( # noqa: E402
BatchSpanProcessor,
SpanExportResult,
)
from opentelemetry.trace import Span, Status, StatusCode # noqa: E402
if TYPE_CHECKING:
from crewai.crew import Crew
from crewai.task import Task
@@ -232,7 +232,7 @@ class Telemetry:
"agent_key": task.agent.key if task.agent else None,
"context": (
[task.description for task in task.context]
if task.context
if isinstance(task.context, list)
else None
),
"tools_names": [
@@ -748,7 +748,7 @@ class Telemetry:
"agent_key": task.agent.key if task.agent else None,
"context": (
[task.description for task in task.context]
if task.context
if isinstance(task.context, list)
else None
),
"tools_names": [

View File

@@ -5,3 +5,14 @@ KNOWLEDGE_DIRECTORY = "knowledge"
MAX_LLM_RETRY = 3
MAX_FILE_NAME_LENGTH = 255
EMITTER_COLOR = "bold_blue"
class _NotSpecified:
def __repr__(self):
return "NOT_SPECIFIED"
# Sentinel value used to detect when no value has been explicitly provided.
# Unlike `None`, which might be a valid value from the user, `NOT_SPECIFIED` allows
# us to distinguish between "not passed at all" and "explicitly passed None" or "[]".
NOT_SPECIFIED = _NotSpecified()

View File

@@ -1,6 +1,6 @@
import re
from typing import TYPE_CHECKING, List
if TYPE_CHECKING:
from crewai.task import Task
from crewai.tasks.task_output import TaskOutput
@@ -17,6 +17,11 @@ def aggregate_raw_outputs_from_task_outputs(task_outputs: List["TaskOutput"]) ->
def aggregate_raw_outputs_from_tasks(tasks: List["Task"]) -> str:
"""Generate string context from the tasks."""
task_outputs = [task.output for task in tasks if task.output is not None]
task_outputs = (
[task.output for task in tasks if task.output is not None]
if isinstance(tasks, list)
else []
)
return aggregate_raw_outputs_from_task_outputs(task_outputs)

View File

@@ -2,22 +2,18 @@
import hashlib
import json
import os
import tempfile
from concurrent.futures import Future
from unittest import mock
from unittest.mock import MagicMock, patch
from unittest.mock import ANY, MagicMock, patch
import pydantic_core
import pytest
from crewai.agent import Agent
from crewai.agents import CacheHandler
from crewai.agents.cache import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.crew import Crew
from crewai.crews.crew_output import CrewOutput
from crewai.flow import Flow, listen, start
from crewai.flow import Flow, start
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource
from crewai.llm import LLM
from crewai.memory.contextual.contextual_memory import ContextualMemory
@@ -3141,6 +3137,30 @@ def test_replay_with_context():
assert crew.tasks[1].context[0].output.raw == "context raw output"
def test_replay_with_context_set_to_nullable():
agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")
task1 = Task(
description="Context Task", expected_output="Say Task Output", agent=agent
)
task2 = Task(
description="Test Task", expected_output="Say Hi", agent=agent, context=[]
)
task3 = Task(
description="Test Task 3", expected_output="Say Hi", agent=agent, context=None
)
crew = Crew(agents=[agent], tasks=[task1, task2, task3], process=Process.sequential)
with patch("crewai.task.Task.execute_sync") as mock_execute_task:
mock_execute_task.return_value = TaskOutput(
description="Test Task Output",
raw="test raw output",
agent="test_agent",
)
crew.kickoff()
mock_execute_task.assert_called_with(agent=ANY, context="", tools=ANY)
@pytest.mark.vcr(filter_headers=["authorization"])
def test_replay_with_invalid_task_id():
agent = Agent(role="test_agent", backstory="Test Description", goal="Test Goal")