mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
feat: add crew context tracking for LLM guardrail events (#3111)
Add crew context tracking using OpenTelemetry baggage for thread-safe propagation. Context is set during kickoff and cleaned up in finally block. Added thread safety tests with mocked agent execution.
This commit is contained in:
@@ -18,6 +18,11 @@ from typing import (
|
||||
cast,
|
||||
)
|
||||
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
@@ -616,6 +621,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self,
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> CrewOutput:
|
||||
ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(ctx)
|
||||
|
||||
try:
|
||||
for before_callback in self.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
@@ -676,6 +686,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
CrewKickoffFailedEvent(error=str(e), crew_name=self.name or "crew"),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
def kickoff_for_each(self, inputs: List[Dict[str, Any]]) -> List[CrewOutput]:
|
||||
"""Executes the Crew's workflow for each input in the list and aggregates results."""
|
||||
|
||||
1
src/crewai/utilities/crew/__init__.py
Normal file
1
src/crewai/utilities/crew/__init__.py
Normal file
@@ -0,0 +1 @@
|
||||
"""Crew-specific utilities."""
|
||||
16
src/crewai/utilities/crew/crew_context.py
Normal file
16
src/crewai/utilities/crew/crew_context.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Context management utilities for tracking crew and task execution context using OpenTelemetry baggage."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from opentelemetry import baggage
|
||||
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
|
||||
def get_crew_context() -> Optional[CrewContext]:
|
||||
"""Get the current crew context from OpenTelemetry baggage.
|
||||
|
||||
Returns:
|
||||
CrewContext instance containing crew context information, or None if no context is set
|
||||
"""
|
||||
return baggage.get_baggage("crew_context")
|
||||
16
src/crewai/utilities/crew/models.py
Normal file
16
src/crewai/utilities/crew/models.py
Normal file
@@ -0,0 +1,16 @@
|
||||
"""Models for crew-related data structures."""
|
||||
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
|
||||
class CrewContext(BaseModel):
|
||||
"""Model representing crew context information."""
|
||||
|
||||
id: Optional[str] = Field(
|
||||
default=None, description="Unique identifier for the crew"
|
||||
)
|
||||
key: Optional[str] = Field(
|
||||
default=None, description="Optional crew key/name for identification"
|
||||
)
|
||||
@@ -1,3 +1,4 @@
|
||||
from inspect import getsource
|
||||
from typing import Any, Callable, Optional, Union
|
||||
|
||||
from crewai.utilities.events.base_events import BaseEvent
|
||||
@@ -16,23 +17,26 @@ class LLMGuardrailStartedEvent(BaseEvent):
|
||||
retry_count: int
|
||||
|
||||
def __init__(self, **data):
|
||||
from inspect import getsource
|
||||
|
||||
from crewai.tasks.llm_guardrail import LLMGuardrail
|
||||
from crewai.tasks.hallucination_guardrail import HallucinationGuardrail
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
if isinstance(self.guardrail, LLMGuardrail) or isinstance(
|
||||
self.guardrail, HallucinationGuardrail
|
||||
):
|
||||
if isinstance(self.guardrail, (LLMGuardrail, HallucinationGuardrail)):
|
||||
self.guardrail = self.guardrail.description.strip()
|
||||
elif isinstance(self.guardrail, Callable):
|
||||
self.guardrail = getsource(self.guardrail).strip()
|
||||
|
||||
|
||||
class LLMGuardrailCompletedEvent(BaseEvent):
|
||||
"""Event emitted when a guardrail task completes"""
|
||||
"""Event emitted when a guardrail task completes
|
||||
|
||||
Attributes:
|
||||
success: Whether the guardrail validation passed
|
||||
result: The validation result
|
||||
error: Error message if validation failed
|
||||
retry_count: The number of times the guardrail has been retried
|
||||
"""
|
||||
|
||||
type: str = "llm_guardrail_completed"
|
||||
success: bool
|
||||
|
||||
226
tests/test_crew_thread_safety.py
Normal file
226
tests/test_crew_thread_safety.py
Normal file
@@ -0,0 +1,226 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Dict, Any, Callable
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
from crewai.utilities.crew.crew_context import get_crew_context
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_agent_factory():
|
||||
def create_agent(name: str) -> Agent:
|
||||
return Agent(
|
||||
role=f"{name} Agent",
|
||||
goal=f"Complete {name} task",
|
||||
backstory=f"I am agent for {name}",
|
||||
)
|
||||
|
||||
return create_agent
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def simple_task_factory():
|
||||
def create_task(name: str, callback: Callable = None) -> Task:
|
||||
return Task(
|
||||
description=f"Task for {name}", expected_output="Done", callback=callback
|
||||
)
|
||||
|
||||
return create_task
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def crew_factory(simple_agent_factory, simple_task_factory):
|
||||
def create_crew(name: str, task_callback: Callable = None) -> Crew:
|
||||
agent = simple_agent_factory(name)
|
||||
task = simple_task_factory(name, callback=task_callback)
|
||||
task.agent = agent
|
||||
|
||||
return Crew(agents=[agent], tasks=[task], verbose=False)
|
||||
|
||||
return create_crew
|
||||
|
||||
|
||||
class TestCrewThreadSafety:
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_parallel_crews_thread_safety(self, mock_execute_task, crew_factory):
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
num_crews = 5
|
||||
|
||||
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
|
||||
results = {"crew_id": crew_id, "contexts": []}
|
||||
|
||||
def check_context_task(output):
|
||||
context = get_crew_context()
|
||||
results["contexts"].append(
|
||||
{
|
||||
"stage": "task_callback",
|
||||
"crew_id": context.id if context else None,
|
||||
"crew_key": context.key if context else None,
|
||||
"thread": threading.current_thread().name,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
context_before = get_crew_context()
|
||||
results["contexts"].append(
|
||||
{
|
||||
"stage": "before_kickoff",
|
||||
"crew_id": context_before.id if context_before else None,
|
||||
"thread": threading.current_thread().name,
|
||||
}
|
||||
)
|
||||
|
||||
crew = crew_factory(crew_id, task_callback=check_context_task)
|
||||
output = crew.kickoff()
|
||||
|
||||
context_after = get_crew_context()
|
||||
results["contexts"].append(
|
||||
{
|
||||
"stage": "after_kickoff",
|
||||
"crew_id": context_after.id if context_after else None,
|
||||
"thread": threading.current_thread().name,
|
||||
}
|
||||
)
|
||||
|
||||
results["crew_uuid"] = str(crew.id)
|
||||
results["output"] = output.raw
|
||||
|
||||
return results
|
||||
|
||||
with ThreadPoolExecutor(max_workers=num_crews) as executor:
|
||||
futures = []
|
||||
for i in range(num_crews):
|
||||
future = executor.submit(run_crew_with_context_check, f"crew_{i}")
|
||||
futures.append(future)
|
||||
|
||||
results = [f.result() for f in futures]
|
||||
|
||||
for result in results:
|
||||
crew_uuid = result["crew_uuid"]
|
||||
|
||||
before_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
|
||||
)
|
||||
assert (
|
||||
before_ctx["crew_id"] is None
|
||||
), f"Context should be None before kickoff for {result['crew_id']}"
|
||||
|
||||
task_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
|
||||
)
|
||||
assert (
|
||||
task_ctx["crew_id"] == crew_uuid
|
||||
), f"Context mismatch during task for {result['crew_id']}"
|
||||
|
||||
after_ctx = next(
|
||||
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
|
||||
)
|
||||
assert (
|
||||
after_ctx["crew_id"] is None
|
||||
), f"Context should be None after kickoff for {result['crew_id']}"
|
||||
|
||||
thread_name = before_ctx["thread"]
|
||||
assert (
|
||||
"ThreadPoolExecutor" in thread_name
|
||||
), f"Should run in thread pool for {result['crew_id']}"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@patch("crewai.Agent.execute_task")
|
||||
async def test_async_crews_thread_safety(self, mock_execute_task, crew_factory):
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
num_crews = 5
|
||||
|
||||
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
|
||||
task_context = {"crew_id": crew_id, "context": None}
|
||||
|
||||
def capture_context(output):
|
||||
ctx = get_crew_context()
|
||||
task_context["context"] = {
|
||||
"crew_id": ctx.id if ctx else None,
|
||||
"crew_key": ctx.key if ctx else None,
|
||||
}
|
||||
return output
|
||||
|
||||
crew = crew_factory(crew_id, task_callback=capture_context)
|
||||
output = await crew.kickoff_async()
|
||||
|
||||
return {
|
||||
"crew_id": crew_id,
|
||||
"crew_uuid": str(crew.id),
|
||||
"output": output.raw,
|
||||
"task_context": task_context,
|
||||
}
|
||||
|
||||
tasks = [run_crew_async(f"async_crew_{i}") for i in range(num_crews)]
|
||||
results = await asyncio.gather(*tasks)
|
||||
|
||||
for result in results:
|
||||
crew_uuid = result["crew_uuid"]
|
||||
task_ctx = result["task_context"]["context"]
|
||||
|
||||
assert (
|
||||
task_ctx is not None
|
||||
), f"Context should exist during task for {result['crew_id']}"
|
||||
assert (
|
||||
task_ctx["crew_id"] == crew_uuid
|
||||
), f"Context mismatch for {result['crew_id']}"
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
contexts_captured = []
|
||||
|
||||
def capture_context(output):
|
||||
ctx = get_crew_context()
|
||||
contexts_captured.append(
|
||||
{
|
||||
"context_id": ctx.id if ctx else None,
|
||||
"thread": threading.current_thread().name,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
crew = crew_factory("for_each_test", task_callback=capture_context)
|
||||
inputs = [{"item": f"input_{i}"} for i in range(3)]
|
||||
|
||||
results = crew.kickoff_for_each(inputs=inputs)
|
||||
|
||||
assert len(results) == len(inputs)
|
||||
assert len(contexts_captured) == len(inputs)
|
||||
|
||||
context_ids = [ctx["context_id"] for ctx in contexts_captured]
|
||||
assert len(set(context_ids)) == len(
|
||||
inputs
|
||||
), "Each execution should have unique context"
|
||||
|
||||
@patch("crewai.Agent.execute_task")
|
||||
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):
|
||||
mock_execute_task.return_value = "Task completed"
|
||||
contexts = []
|
||||
|
||||
def check_context(output):
|
||||
ctx = get_crew_context()
|
||||
contexts.append(
|
||||
{
|
||||
"context_id": ctx.id if ctx else None,
|
||||
"context_key": ctx.key if ctx else None,
|
||||
}
|
||||
)
|
||||
return output
|
||||
|
||||
def run_crew(name: str):
|
||||
crew = crew_factory(name, task_callback=check_context)
|
||||
crew.kickoff()
|
||||
return str(crew.id)
|
||||
|
||||
crew1_id = run_crew("First")
|
||||
crew2_id = run_crew("Second")
|
||||
|
||||
assert len(contexts) == 2
|
||||
assert contexts[0]["context_id"] == crew1_id
|
||||
assert contexts[1]["context_id"] == crew2_id
|
||||
assert contexts[0]["context_id"] != contexts[1]["context_id"]
|
||||
0
tests/utilities/crew/__init__.py
Normal file
0
tests/utilities/crew/__init__.py
Normal file
88
tests/utilities/crew/test_crew_context.py
Normal file
88
tests/utilities/crew/test_crew_context.py
Normal file
@@ -0,0 +1,88 @@
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
from opentelemetry import baggage
|
||||
from opentelemetry.context import attach, detach
|
||||
|
||||
from crewai.utilities.crew.crew_context import get_crew_context
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
|
||||
|
||||
def test_crew_context_creation():
|
||||
crew_id = str(uuid.uuid4())
|
||||
context = CrewContext(id=crew_id, key="test-crew")
|
||||
assert context.id == crew_id
|
||||
assert context.key == "test-crew"
|
||||
|
||||
|
||||
def test_get_crew_context_with_baggage():
|
||||
crew_id = str(uuid.uuid4())
|
||||
assert get_crew_context() is None
|
||||
|
||||
crew_ctx = CrewContext(id=crew_id, key="test-key")
|
||||
ctx = baggage.set_baggage("crew_context", crew_ctx)
|
||||
token = attach(ctx)
|
||||
|
||||
try:
|
||||
context = get_crew_context()
|
||||
assert context is not None
|
||||
assert context.id == crew_id
|
||||
assert context.key == "test-key"
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
assert get_crew_context() is None
|
||||
|
||||
|
||||
def test_get_crew_context_empty():
|
||||
assert get_crew_context() is None
|
||||
|
||||
|
||||
def test_baggage_nested_contexts():
|
||||
crew_id1 = str(uuid.uuid4())
|
||||
crew_id2 = str(uuid.uuid4())
|
||||
|
||||
crew_ctx1 = CrewContext(id=crew_id1, key="outer")
|
||||
ctx1 = baggage.set_baggage("crew_context", crew_ctx1)
|
||||
token1 = attach(ctx1)
|
||||
|
||||
try:
|
||||
outer_context = get_crew_context()
|
||||
assert outer_context.id == crew_id1
|
||||
assert outer_context.key == "outer"
|
||||
|
||||
crew_ctx2 = CrewContext(id=crew_id2, key="inner")
|
||||
ctx2 = baggage.set_baggage("crew_context", crew_ctx2)
|
||||
token2 = attach(ctx2)
|
||||
|
||||
try:
|
||||
inner_context = get_crew_context()
|
||||
assert inner_context.id == crew_id2
|
||||
assert inner_context.key == "inner"
|
||||
finally:
|
||||
detach(token2)
|
||||
|
||||
restored_context = get_crew_context()
|
||||
assert restored_context.id == crew_id1
|
||||
assert restored_context.key == "outer"
|
||||
finally:
|
||||
detach(token1)
|
||||
|
||||
assert get_crew_context() is None
|
||||
|
||||
|
||||
def test_baggage_exception_handling():
|
||||
crew_id = str(uuid.uuid4())
|
||||
|
||||
crew_ctx = CrewContext(id=crew_id, key="test")
|
||||
ctx = baggage.set_baggage("crew_context", crew_ctx)
|
||||
token = attach(ctx)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
try:
|
||||
assert get_crew_context() is not None
|
||||
raise ValueError("Test exception")
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
assert get_crew_context() is None
|
||||
Reference in New Issue
Block a user