From 34a03f882c12b3bd43de82e09f933bd64a894f49 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Mon, 7 Jul 2025 16:33:07 -0400 Subject: [PATCH] feat: add crew context tracking for LLM guardrail events (#3111) Add crew context tracking using OpenTelemetry baggage for thread-safe propagation. Context is set during kickoff and cleaned up in finally block. Added thread safety tests with mocked agent execution. --- src/crewai/crew.py | 12 + src/crewai/utilities/crew/__init__.py | 1 + src/crewai/utilities/crew/crew_context.py | 16 ++ src/crewai/utilities/crew/models.py | 16 ++ .../utilities/events/llm_guardrail_events.py | 16 +- tests/test_crew_thread_safety.py | 226 ++++++++++++++++++ tests/utilities/crew/__init__.py | 0 tests/utilities/crew/test_crew_context.py | 88 +++++++ 8 files changed, 369 insertions(+), 6 deletions(-) create mode 100644 src/crewai/utilities/crew/__init__.py create mode 100644 src/crewai/utilities/crew/crew_context.py create mode 100644 src/crewai/utilities/crew/models.py create mode 100644 tests/test_crew_thread_safety.py create mode 100644 tests/utilities/crew/__init__.py create mode 100644 tests/utilities/crew/test_crew_context.py diff --git a/src/crewai/crew.py b/src/crewai/crew.py index d4b14e833..e1b7dc04f 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -18,6 +18,11 @@ from typing import ( cast, ) +from opentelemetry import baggage +from opentelemetry.context import attach, detach + +from crewai.utilities.crew.models import CrewContext + from pydantic import ( UUID4, BaseModel, @@ -616,6 +621,11 @@ class Crew(FlowTrackable, BaseModel): self, inputs: Optional[Dict[str, Any]] = None, ) -> CrewOutput: + ctx = baggage.set_baggage( + "crew_context", CrewContext(id=str(self.id), key=self.key) + ) + token = attach(ctx) + try: for before_callback in self.before_kickoff_callbacks: if inputs is None: @@ -676,6 +686,8 @@ class Crew(FlowTrackable, BaseModel): CrewKickoffFailedEvent(error=str(e), crew_name=self.name or "crew"), ) raise + finally: + detach(token) def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]: """Executes the Crew's workflow for each input in the list and aggregates results.""" diff --git a/src/crewai/utilities/crew/__init__.py b/src/crewai/utilities/crew/__init__.py new file mode 100644 index 000000000..db74f269b --- /dev/null +++ b/src/crewai/utilities/crew/__init__.py @@ -0,0 +1 @@ +"""Crew-specific utilities.""" \ No newline at end of file diff --git a/src/crewai/utilities/crew/crew_context.py b/src/crewai/utilities/crew/crew_context.py new file mode 100644 index 000000000..3f287b566 --- /dev/null +++ b/src/crewai/utilities/crew/crew_context.py @@ -0,0 +1,16 @@ +"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage.""" + +from typing import Optional + +from opentelemetry import baggage + +from crewai.utilities.crew.models import CrewContext + + +def get_crew_context() -> Optional[CrewContext]: + """Get the current crew context from OpenTelemetry baggage. + + Returns: + CrewContext instance containing crew context information, or None if no context is set + """ + return baggage.get_baggage("crew_context") diff --git a/src/crewai/utilities/crew/models.py b/src/crewai/utilities/crew/models.py new file mode 100644 index 000000000..78a1f33a6 --- /dev/null +++ b/src/crewai/utilities/crew/models.py @@ -0,0 +1,16 @@ +"""Models for crew-related data structures.""" + +from typing import Optional + +from pydantic import BaseModel, Field + + +class CrewContext(BaseModel): + """Model representing crew context information.""" + + id: Optional[str] = Field( + default=None, description="Unique identifier for the crew" + ) + key: Optional[str] = Field( + default=None, description="Optional crew key/name for identification" + ) diff --git a/src/crewai/utilities/events/llm_guardrail_events.py b/src/crewai/utilities/events/llm_guardrail_events.py index 01831e12c..d60e226f4 100644 --- a/src/crewai/utilities/events/llm_guardrail_events.py +++ b/src/crewai/utilities/events/llm_guardrail_events.py @@ -1,3 +1,4 @@ +from inspect import getsource from typing import Any, Callable, Optional, Union from crewai.utilities.events.base_events import BaseEvent @@ -16,23 +17,26 @@ class LLMGuardrailStartedEvent(BaseEvent): retry_count: int def __init__(self, **data): - from inspect import getsource - from crewai.tasks.llm_guardrail import LLMGuardrail from crewai.tasks.hallucination_guardrail import HallucinationGuardrail super().__init__(**data) - if isinstance(self.guardrail, LLMGuardrail) or isinstance( - self.guardrail, HallucinationGuardrail - ): + if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)): self.guardrail = self.guardrail.description.strip() elif isinstance(self.guardrail, Callable): self.guardrail = getsource(self.guardrail).strip() class LLMGuardrailCompletedEvent(BaseEvent): - """Event emitted when a guardrail task completes""" + """Event emitted when a guardrail task completes + + Attributes: + success: Whether the guardrail validation passed + result: The validation result + error: Error message if validation failed + retry_count: The number of times the guardrail has been retried + """ type: str = "llm_guardrail_completed" success: bool diff --git a/tests/test_crew_thread_safety.py b/tests/test_crew_thread_safety.py new file mode 100644 index 000000000..145a0405c --- /dev/null +++ b/tests/test_crew_thread_safety.py @@ -0,0 +1,226 @@ +import asyncio +import threading +from concurrent.futures import ThreadPoolExecutor +from typing import Dict, Any, Callable +from unittest.mock import patch + +import pytest + +from crewai import Agent, Crew, Task +from crewai.utilities.crew.crew_context import get_crew_context + + +@pytest.fixture +def simple_agent_factory(): + def create_agent(name: str) -> Agent: + return Agent( + role=f"{name} Agent", + goal=f"Complete {name} task", + backstory=f"I am agent for {name}", + ) + + return create_agent + + +@pytest.fixture +def simple_task_factory(): + def create_task(name: str, callback: Callable = None) -> Task: + return Task( + description=f"Task for {name}", expected_output="Done", callback=callback + ) + + return create_task + + +@pytest.fixture +def crew_factory(simple_agent_factory, simple_task_factory): + def create_crew(name: str, task_callback: Callable = None) -> Crew: + agent = simple_agent_factory(name) + task = simple_task_factory(name, callback=task_callback) + task.agent = agent + + return Crew(agents=[agent], tasks=[task], verbose=False) + + return create_crew + + +class TestCrewThreadSafety: + @patch("crewai.Agent.execute_task") + def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory): + mock_execute_task.return_value = "Task completed" + num_crews = 5 + + def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]: + results = {"crew_id": crew_id, "contexts": []} + + def check_context_task(output): + context = get_crew_context() + results["contexts"].append( + { + "stage": "task_callback", + "crew_id": context.id if context else None, + "crew_key": context.key if context else None, + "thread": threading.current_thread().name, + } + ) + return output + + context_before = get_crew_context() + results["contexts"].append( + { + "stage": "before_kickoff", + "crew_id": context_before.id if context_before else None, + "thread": threading.current_thread().name, + } + ) + + crew = crew_factory(crew_id, task_callback=check_context_task) + output = crew.kickoff() + + context_after = get_crew_context() + results["contexts"].append( + { + "stage": "after_kickoff", + "crew_id": context_after.id if context_after else None, + "thread": threading.current_thread().name, + } + ) + + results["crew_uuid"] = str(crew.id) + results["output"] = output.raw + + return results + + with ThreadPoolExecutor(max_workers=num_crews) as executor: + futures = [] + for i in range(num_crews): + future = executor.submit(run_crew_with_context_check, f"crew_{i}") + futures.append(future) + + results = [f.result() for f in futures] + + for result in results: + crew_uuid = result["crew_uuid"] + + before_ctx = next( + ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff" + ) + assert ( + before_ctx["crew_id"] is None + ), f"Context should be None before kickoff for {result['crew_id']}" + + task_ctx = next( + ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback" + ) + assert ( + task_ctx["crew_id"] == crew_uuid + ), f"Context mismatch during task for {result['crew_id']}" + + after_ctx = next( + ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff" + ) + assert ( + after_ctx["crew_id"] is None + ), f"Context should be None after kickoff for {result['crew_id']}" + + thread_name = before_ctx["thread"] + assert ( + "ThreadPoolExecutor" in thread_name + ), f"Should run in thread pool for {result['crew_id']}" + + @pytest.mark.asyncio + @patch("crewai.Agent.execute_task") + async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory): + mock_execute_task.return_value = "Task completed" + num_crews = 5 + + async def run_crew_async(crew_id: str) -> Dict[str, Any]: + task_context = {"crew_id": crew_id, "context": None} + + def capture_context(output): + ctx = get_crew_context() + task_context["context"] = { + "crew_id": ctx.id if ctx else None, + "crew_key": ctx.key if ctx else None, + } + return output + + crew = crew_factory(crew_id, task_callback=capture_context) + output = await crew.kickoff_async() + + return { + "crew_id": crew_id, + "crew_uuid": str(crew.id), + "output": output.raw, + "task_context": task_context, + } + + tasks = [run_crew_async(f"async_crew_{i}") for i in range(num_crews)] + results = await asyncio.gather(*tasks) + + for result in results: + crew_uuid = result["crew_uuid"] + task_ctx = result["task_context"]["context"] + + assert ( + task_ctx is not None + ), f"Context should exist during task for {result['crew_id']}" + assert ( + task_ctx["crew_id"] == crew_uuid + ), f"Context mismatch for {result['crew_id']}" + + @patch("crewai.Agent.execute_task") + def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory): + mock_execute_task.return_value = "Task completed" + contexts_captured = [] + + def capture_context(output): + ctx = get_crew_context() + contexts_captured.append( + { + "context_id": ctx.id if ctx else None, + "thread": threading.current_thread().name, + } + ) + return output + + crew = crew_factory("for_each_test", task_callback=capture_context) + inputs = [{"item": f"input_{i}"} for i in range(3)] + + results = crew.kickoff_for_each(inputs=inputs) + + assert len(results) == len(inputs) + assert len(contexts_captured) == len(inputs) + + context_ids = [ctx["context_id"] for ctx in contexts_captured] + assert len(set(context_ids)) == len( + inputs + ), "Each execution should have unique context" + + @patch("crewai.Agent.execute_task") + def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory): + mock_execute_task.return_value = "Task completed" + contexts = [] + + def check_context(output): + ctx = get_crew_context() + contexts.append( + { + "context_id": ctx.id if ctx else None, + "context_key": ctx.key if ctx else None, + } + ) + return output + + def run_crew(name: str): + crew = crew_factory(name, task_callback=check_context) + crew.kickoff() + return str(crew.id) + + crew1_id = run_crew("First") + crew2_id = run_crew("Second") + + assert len(contexts) == 2 + assert contexts[0]["context_id"] == crew1_id + assert contexts[1]["context_id"] == crew2_id + assert contexts[0]["context_id"] != contexts[1]["context_id"] diff --git a/tests/utilities/crew/__init__.py b/tests/utilities/crew/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/utilities/crew/test_crew_context.py b/tests/utilities/crew/test_crew_context.py new file mode 100644 index 000000000..29ce5a356 --- /dev/null +++ b/tests/utilities/crew/test_crew_context.py @@ -0,0 +1,88 @@ +import uuid + +import pytest +from opentelemetry import baggage +from opentelemetry.context import attach, detach + +from crewai.utilities.crew.crew_context import get_crew_context +from crewai.utilities.crew.models import CrewContext + + +def test_crew_context_creation(): + crew_id = str(uuid.uuid4()) + context = CrewContext(id=crew_id, key="test-crew") + assert context.id == crew_id + assert context.key == "test-crew" + + +def test_get_crew_context_with_baggage(): + crew_id = str(uuid.uuid4()) + assert get_crew_context() is None + + crew_ctx = CrewContext(id=crew_id, key="test-key") + ctx = baggage.set_baggage("crew_context", crew_ctx) + token = attach(ctx) + + try: + context = get_crew_context() + assert context is not None + assert context.id == crew_id + assert context.key == "test-key" + finally: + detach(token) + + assert get_crew_context() is None + + +def test_get_crew_context_empty(): + assert get_crew_context() is None + + +def test_baggage_nested_contexts(): + crew_id1 = str(uuid.uuid4()) + crew_id2 = str(uuid.uuid4()) + + crew_ctx1 = CrewContext(id=crew_id1, key="outer") + ctx1 = baggage.set_baggage("crew_context", crew_ctx1) + token1 = attach(ctx1) + + try: + outer_context = get_crew_context() + assert outer_context.id == crew_id1 + assert outer_context.key == "outer" + + crew_ctx2 = CrewContext(id=crew_id2, key="inner") + ctx2 = baggage.set_baggage("crew_context", crew_ctx2) + token2 = attach(ctx2) + + try: + inner_context = get_crew_context() + assert inner_context.id == crew_id2 + assert inner_context.key == "inner" + finally: + detach(token2) + + restored_context = get_crew_context() + assert restored_context.id == crew_id1 + assert restored_context.key == "outer" + finally: + detach(token1) + + assert get_crew_context() is None + + +def test_baggage_exception_handling(): + crew_id = str(uuid.uuid4()) + + crew_ctx = CrewContext(id=crew_id, key="test") + ctx = baggage.set_baggage("crew_context", crew_ctx) + token = attach(ctx) + + with pytest.raises(ValueError): + try: + assert get_crew_context() is not None + raise ValueError("Test exception") + finally: + detach(token) + + assert get_crew_context() is None