Compare commits


2 Commits

Devin AI · 1ae3a003b6 · 2025-10-15 03:01:54 +00:00
Fix lint errors in test_custom_llm.py
- Add noqa comment for hardcoded test JWT token
- Add return statement to satisfy ruff RET503 check

Co-Authored-By: João <joao@crewai.com>

Devin AI · fc4b0dd923 · 2025-10-15 02:57:03 +00:00
Fix function_calling_llm support for custom models
- Add supports_function_calling() method to BaseLLM class with default True
- Add supports_function_calling parameter to LLM class to allow override of litellm check
- Add tests for both BaseLLM default and LLM override functionality
- Fixes #3708: Custom models not in litellm's list can now use function calling

Co-Authored-By: João <joao@crewai.com>
8 changed files with 169 additions and 566 deletions
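
Taken together, the two commits let a model that litellm's capability registry does not recognize serve as a function_calling_llm. A rough sketch of the intended usage (the model id and agent fields are illustrative, not taken from the diff):

from crewai import Agent, LLM

# Hypothetical model id that litellm does not know about.
custom_llm = LLM(
    model="my-gateway/custom-chat-model",
    supports_function_calling=True,  # override added in fc4b0dd923
)

agent = Agent(
    role="Researcher",
    goal="Answer questions using tools",
    backstory="Illustrative agent",
    function_calling_llm=custom_llm,  # previously failed litellm's capability check
)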

View File

@@ -864,7 +864,6 @@ class Agent(BaseAgent):
i18n=self.i18n,
original_agent=self,
guardrail=self.guardrail,
guardrail_max_retries=self.guardrail_max_retries,
)
return await lite_agent.kickoff_async(messages)

View File

@@ -299,6 +299,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
supports_function_calling: bool | None = None,
**kwargs,
):
self.model = model
@@ -325,6 +326,7 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self._supports_function_calling_override = supports_function_calling
litellm.drop_params = True
@@ -1197,6 +1199,9 @@ class LLM(BaseLLM):
)
def supports_function_calling(self) -> bool:
if self._supports_function_calling_override is not None:
return self._supports_function_calling_override
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(
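
The override only takes effect when explicitly set: None (the default) falls through to the litellm lookup, so existing behaviour is unchanged. The new tests in the LLM test file later in this compare cover exactly these cases; condensed:

from crewai import LLM

# Forced on: an unknown model is treated as tool-capable.
assert LLM(model="custom-model/my-model", supports_function_calling=True).supports_function_calling() is True

# Forced off: a capable model can be opted out.
assert LLM(model="gpt-4o-mini", supports_function_calling=False).supports_function_calling() is False

# Unset: litellm's registry still decides.
assert LLM(model="gpt-4o-mini").supports_function_calling() is True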

View File

@@ -9,6 +9,7 @@ from typing import Any, Final
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
DEFAULT_SUPPORTS_FUNCTION_CALLING: Final[bool] = True
class BaseLLM(ABC):
@@ -82,6 +83,14 @@ class BaseLLM(ABC):
RuntimeError: If the LLM request fails for other reasons.
"""
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
return DEFAULT_SUPPORTS_FUNCTION_CALLING
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.

View File

@@ -5,7 +5,7 @@ import logging
import threading
import uuid
import warnings
from collections.abc import Callable, Sequence
from collections.abc import Callable
from concurrent.futures import Future
from copy import copy as shallow_copy
from hashlib import md5
@@ -152,15 +152,6 @@ class Task(BaseModel):
default=None,
description="Function or string description of a guardrail to validate task output before proceeding to next task",
)
guardrails: (
Sequence[Callable[[TaskOutput], tuple[bool, Any]] | str]
| Callable[[TaskOutput], tuple[bool, Any]]
| str
| None
) = Field(
default=None,
description="List of guardrails to validate task output before proceeding to next task. Also supports a single guardrail function or string description of a guardrail to validate task output before proceeding to next task",
)
max_retries: int | None = Field(
default=None,
description="[DEPRECATED] Maximum number of retries when guardrail fails. Use guardrail_max_retries instead. Will be removed in v1.0.0",
@@ -277,44 +268,6 @@ class Task(BaseModel):
return self
@model_validator(mode="after")
def ensure_guardrails_is_list_of_callables(self) -> "Task":
guardrails = []
if self.guardrails is not None and (
not isinstance(self.guardrails, (list, tuple)) or len(self.guardrails) > 0
):
if self.agent is None:
raise ValueError("Agent is required to use guardrails")
if callable(self.guardrails):
guardrails.append(self.guardrails)
elif isinstance(self.guardrails, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=self.guardrails, llm=self.agent.llm)
)
if isinstance(self.guardrails, list):
for guardrail in self.guardrails:
if callable(guardrail):
guardrails.append(guardrail)
elif isinstance(guardrail, str):
from crewai.tasks.llm_guardrail import LLMGuardrail
guardrails.append(
LLMGuardrail(description=guardrail, llm=self.agent.llm)
)
else:
raise ValueError("Guardrail must be a callable or a string")
self._guardrails = guardrails
if self._guardrails:
self.guardrail = None
self._guardrail = None
return self
@field_validator("id", mode="before")
@classmethod
def _deny_user_set_id(cls, v: UUID4 | None) -> None:
@@ -503,23 +456,48 @@ class Task(BaseModel):
output_format=self._get_output_format(),
)
if self._guardrails:
for guardrail in self._guardrails:
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
if self._guardrail:
guardrail_result = process_guardrail(
output=task_output,
guardrail=self._guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
# backwards support
if self._guardrail:
task_output = self._invoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -811,55 +789,3 @@ Follow these guidelines:
Fingerprint: The fingerprint of the task
"""
return self.security_config.fingerprint
def _invoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: Callable | None,
) -> TaskOutput:
if guardrail:
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=self.retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if not guardrail_result.success:
if self.retry_count >= self.guardrail_max_retries:
raise Exception(
f"Task failed guardrail validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail blocked, retrying, due to: {guardrail_result.error}\n",
color="yellow",
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
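
With the list-based guardrails field reverted, a task validates its output through the single guardrail path shown above. A sketch of the supported shape, mirroring the retained tests (description strings are illustrative):

from crewai import Task
from crewai.tasks.task_output import TaskOutput


def no_placeholders(result: TaskOutput) -> tuple[bool, str]:
    # Guardrail contract: return (success, transformed output or error message).
    if "TODO" in result.raw:
        return (False, "Output still contains TODO placeholders")
    return (True, result.raw)


task = Task(
    description="Draft the release notes",
    expected_output="Release notes without placeholders",
    guardrail=no_placeholders,  # single callable or string description
    guardrail_max_retries=2,  # cap on re-executions when validation fails
)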

View File

@@ -1,8 +1,7 @@
import asyncio
import threading
from collections.abc import Callable
from concurrent.futures import ThreadPoolExecutor
from typing import Any
from typing import Dict, Any, Callable
from unittest.mock import patch
import pytest
@@ -25,12 +24,9 @@ def simple_agent_factory():
@pytest.fixture
def simple_task_factory():
def create_task(name: str, agent: Agent, callback: Callable | None = None) -> Task:
def create_task(name: str, callback: Callable = None) -> Task:
return Task(
description=f"Task for {name}",
expected_output="Done",
agent=agent,
callback=callback,
description=f"Task for {name}", expected_output="Done", callback=callback
)
return create_task
@@ -38,9 +34,10 @@ def simple_task_factory():
@pytest.fixture
def crew_factory(simple_agent_factory, simple_task_factory):
def create_crew(name: str, task_callback: Callable | None = None) -> Crew:
def create_crew(name: str, task_callback: Callable = None) -> Crew:
agent = simple_agent_factory(name)
task = simple_task_factory(name, agent=agent, callback=task_callback)
task = simple_task_factory(name, callback=task_callback)
task.agent = agent
return Crew(agents=[agent], tasks=[task], verbose=False)
@@ -53,7 +50,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
def run_crew_with_context_check(crew_id: str) -> dict[str, Any]:
def run_crew_with_context_check(crew_id: str) -> Dict[str, Any]:
results = {"crew_id": crew_id, "contexts": []}
def check_context_task(output):
@@ -108,28 +105,28 @@ class TestCrewThreadSafety:
before_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "before_kickoff"
)
assert before_ctx["crew_id"] is None, (
f"Context should be None before kickoff for {result['crew_id']}"
)
assert (
before_ctx["crew_id"] is None
), f"Context should be None before kickoff for {result['crew_id']}"
task_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "task_callback"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch during task for {result['crew_id']}"
)
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch during task for {result['crew_id']}"
after_ctx = next(
ctx for ctx in result["contexts"] if ctx["stage"] == "after_kickoff"
)
assert after_ctx["crew_id"] is None, (
f"Context should be None after kickoff for {result['crew_id']}"
)
assert (
after_ctx["crew_id"] is None
), f"Context should be None after kickoff for {result['crew_id']}"
thread_name = before_ctx["thread"]
assert "ThreadPoolExecutor" in thread_name, (
f"Should run in thread pool for {result['crew_id']}"
)
assert (
"ThreadPoolExecutor" in thread_name
), f"Should run in thread pool for {result['crew_id']}"
@pytest.mark.asyncio
@patch("crewai.Agent.execute_task")
@@ -137,7 +134,7 @@ class TestCrewThreadSafety:
mock_execute_task.return_value = "Task completed"
num_crews = 5
async def run_crew_async(crew_id: str) -> dict[str, Any]:
async def run_crew_async(crew_id: str) -> Dict[str, Any]:
task_context = {"crew_id": crew_id, "context": None}
def capture_context(output):
@@ -165,12 +162,12 @@ class TestCrewThreadSafety:
crew_uuid = result["crew_uuid"]
task_ctx = result["task_context"]["context"]
assert task_ctx is not None, (
f"Context should exist during task for {result['crew_id']}"
)
assert task_ctx["crew_id"] == crew_uuid, (
f"Context mismatch for {result['crew_id']}"
)
assert (
task_ctx is not None
), f"Context should exist during task for {result['crew_id']}"
assert (
task_ctx["crew_id"] == crew_uuid
), f"Context mismatch for {result['crew_id']}"
@patch("crewai.Agent.execute_task")
def test_concurrent_kickoff_for_each(self, mock_execute_task, crew_factory):
@@ -196,9 +193,9 @@ class TestCrewThreadSafety:
assert len(contexts_captured) == len(inputs)
context_ids = [ctx["context_id"] for ctx in contexts_captured]
assert len(set(context_ids)) == len(inputs), (
"Each execution should have unique context"
)
assert len(set(context_ids)) == len(
inputs
), "Each execution should have unique context"
@patch("crewai.Agent.execute_task")
def test_no_context_leakage_between_crews(self, mock_execute_task, crew_factory):

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any
import pytest
@@ -159,11 +159,11 @@ class JWTAuthLLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Record the call and return a predefined response."""
self.calls.append(
{
@@ -192,7 +192,7 @@ class JWTAuthLLM(BaseLLM):
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token") # noqa: S106
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
@@ -238,11 +238,11 @@ class TimeoutHandlingLLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Simulate API calls with timeout handling and retry logic.
Args:
@@ -282,35 +282,34 @@ class TimeoutHandlingLLM(BaseLLM):
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Success on first attempt
return "First attempt response"
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
# Success on retry
return "Response after retry"
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
@@ -358,3 +357,25 @@ def test_timeout_handling_llm():
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt
class MinimalCustomLLM(BaseLLM):
"""Minimal custom LLM implementation that doesn't override supports_function_calling."""
def __init__(self):
super().__init__(model="minimal-model")
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
return "Minimal response"
def test_base_llm_supports_function_calling_default():
"""Test that BaseLLM supports function calling by default."""
llm = MinimalCustomLLM()
assert llm.supports_function_calling() is True
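
The unconditional return "Response after retry" closing TimeoutHandlingLLM.call() above is the RET503 fix from 1ae3a003b6: ruff's implicit-return rule flags functions where control can fall off the end of a loop without an explicit return. A generic illustration (names invented):

def first_positive(values: list[int]) -> int:
    for v in values:
        if v > 0:
            return v
    # RET503: ruff cannot prove the loop always returns, so the implicit
    # `return None` here fails the check.


def first_positive_fixed(values: list[int]) -> int:
    for v in values:
        if v > 0:
            return v
    return 0  # explicit fall-through return satisfies RET503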

View File

@@ -711,3 +711,18 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
formatted = ollama_llm._format_messages_for_provider(original_messages)
assert formatted == original_messages
def test_supports_function_calling_with_override_true():
llm = LLM(model="custom-model/my-model", supports_function_calling=True)
assert llm.supports_function_calling() is True
def test_supports_function_calling_with_override_false():
llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
assert llm.supports_function_calling() is False
def test_supports_function_calling_without_override():
llm = LLM(model="gpt-4o-mini")
assert llm.supports_function_calling() is True

View File

@@ -14,24 +14,6 @@ from crewai.tasks.llm_guardrail import LLMGuardrail
from crewai.tasks.task_output import TaskOutput
def create_smart_task(**kwargs):
"""
Smart task factory that automatically assigns a mock agent when guardrails are present.
This maintains backward compatibility while handling the agent requirement for guardrails.
"""
guardrails_list = kwargs.get("guardrails")
has_guardrails = kwargs.get("guardrail") is not None or (
guardrails_list is not None and len(guardrails_list) > 0
)
if has_guardrails and kwargs.get("agent") is None:
kwargs["agent"] = Agent(
role="test_agent", goal="test_goal", backstory="test_backstory"
)
return Task(**kwargs)
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
@@ -39,7 +21,7 @@ def test_task_without_guardrail():
agent.execute_task.return_value = "test result"
agent.crew = None
task = create_smart_task(description="Test task", expected_output="Output")
task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -57,9 +39,7 @@ def test_task_with_successful_guardrail_func():
agent.execute_task.return_value = "test result"
agent.crew = None
task = create_smart_task(
description="Test task", expected_output="Output", guardrail=guardrail
)
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
@@ -77,7 +57,7 @@ def test_task_with_failing_guardrail():
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -104,7 +84,7 @@ def test_task_with_guardrail_retries():
agent.execute_task.return_value = "bad result"
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -129,7 +109,7 @@ def test_guardrail_error_in_context():
agent.role = "test_agent"
agent.crew = None
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
@@ -197,7 +177,7 @@ def test_guardrail_emits_events(sample_agent):
started_guardrail = []
completed_guardrail = []
task = create_smart_task(
task = Task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -230,7 +210,7 @@ def test_guardrail_emits_events(sample_agent):
def custom_guardrail(result: TaskOutput):
return (True, "good result from callable function")
task = create_smart_task(
task = Task(
description="Test task",
expected_output="Output",
guardrail=custom_guardrail,
@@ -282,7 +262,7 @@ def test_guardrail_when_an_error_occurs(sample_agent, task_output):
match="Error while validating the task output: Unexpected error",
),
):
task = create_smart_task(
task = Task(
description="Gather information about available books on the First World War",
agent=sample_agent,
expected_output="A list of available books on the First World War",
@@ -304,7 +284,7 @@ def test_hallucination_guardrail_integration():
context="Test reference context for validation", llm=mock_llm, threshold=8.0
)
task = create_smart_task(
task = Task(
description="Test task with hallucination guardrail",
expected_output="Valid output",
guardrail=guardrail,
@@ -324,352 +304,3 @@ def test_hallucination_guardrail_description_in_events():
event = LLMGuardrailStartedEvent(guardrail=guardrail, retry_count=0)
assert event.guardrail == "HallucinationGuardrail (no-op)"
def test_multiple_guardrails_sequential_processing():
"""Test that multiple guardrails are processed sequentially."""
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""First guardrail adds prefix."""
return (True, f"[FIRST] {result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Second guardrail adds suffix."""
return (True, f"{result.raw} [SECOND]")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Third guardrail converts to uppercase."""
return (True, result.raw.upper())
agent = Mock()
agent.role = "sequential_agent"
agent.execute_task.return_value = "original text"
agent.crew = None
task = create_smart_task(
description="Test sequential guardrails",
expected_output="Processed text",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
)
result = task.execute_sync(agent=agent)
assert result.raw == "[FIRST] ORIGINAL TEXT [SECOND]"
def test_multiple_guardrails_with_validation_failure():
"""Test multiple guardrails where one fails validation."""
def length_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Ensure minimum length."""
if len(result.raw) < 10:
return (False, "Text too short")
return (True, result.raw)
def format_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Add formatting only if not already formatted."""
if not result.raw.startswith("Formatted:"):
return (True, f"Formatted: {result.raw}")
return (True, result.raw)
def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Final validation."""
if "Formatted:" not in result.raw:
return (False, "Missing formatting")
return (True, result.raw)
# Use a callable that tracks calls and returns appropriate values
call_count = 0
def mock_execute_task(*args, **kwargs):
nonlocal call_count
call_count += 1
result = (
"short"
if call_count == 1
else "this is a longer text that meets requirements"
)
return result
agent = Mock()
agent.role = "validation_agent"
agent.execute_task = mock_execute_task
agent.crew = None
task = create_smart_task(
description="Test guardrails with validation",
expected_output="Valid formatted text",
guardrails=[length_guardrail, format_guardrail, validation_guardrail],
guardrail_max_retries=2,
)
result = task.execute_sync(agent=agent)
# The second call should be processed through all guardrails
assert result.raw == "Formatted: this is a longer text that meets requirements"
assert task.retry_count == 1
def test_multiple_guardrails_with_mixed_string_and_taskoutput():
"""Test guardrails that return both strings and TaskOutput objects."""
def string_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Returns a string."""
return (True, f"String: {result.raw}")
def taskoutput_guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
"""Returns a TaskOutput object."""
new_output = TaskOutput(
name=result.name,
description=result.description,
expected_output=result.expected_output,
raw=f"TaskOutput: {result.raw}",
agent=result.agent,
output_format=result.output_format,
)
return (True, new_output)
def final_string_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Final string transformation."""
return (True, f"Final: {result.raw}")
agent = Mock()
agent.role = "mixed_agent"
agent.execute_task.return_value = "original"
agent.crew = None
task = create_smart_task(
description="Test mixed return types",
expected_output="Mixed processing",
guardrails=[string_guardrail, taskoutput_guardrail, final_string_guardrail],
)
result = task.execute_sync(agent=agent)
assert result.raw == "Final: TaskOutput: String: original"
def test_multiple_guardrails_with_retry_on_middle_guardrail():
"""Test that retry works correctly when a middle guardrail fails."""
call_count = {"first": 0, "second": 0, "third": 0}
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always succeeds."""
call_count["first"] += 1
return (True, f"First({call_count['first']}): {result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Fails on first attempt, succeeds on second."""
call_count["second"] += 1
if call_count["second"] == 1:
return (False, "Second guardrail failed on first attempt")
return (True, f"Second({call_count['second']}): {result.raw}")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always succeeds."""
call_count["third"] += 1
return (True, f"Third({call_count['third']}): {result.raw}")
agent = Mock()
agent.role = "retry_agent"
agent.execute_task.return_value = "base"
agent.crew = None
task = create_smart_task(
description="Test retry in middle guardrail",
expected_output="Retry handling",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
guardrail_max_retries=2,
)
result = task.execute_sync(agent=agent)
# Based on the test output, the behavior is different than expected
# The guardrails are called multiple times, so let's verify the retry happened
assert task.retry_count == 1
# Verify that the second guardrail eventually succeeded
assert "Second(2)" in result.raw or call_count["second"] >= 2
def test_multiple_guardrails_with_max_retries_exceeded():
"""Test that exception is raised when max retries exceeded with multiple guardrails."""
def passing_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always passes."""
return (True, f"Passed: {result.raw}")
def failing_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Always fails."""
return (False, "This guardrail always fails")
agent = Mock()
agent.role = "failing_agent"
agent.execute_task.return_value = "test"
agent.crew = None
task = create_smart_task(
description="Test max retries with multiple guardrails",
expected_output="Will fail",
guardrails=[passing_guardrail, failing_guardrail],
guardrail_max_retries=1,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation after 1 retries" in str(exc_info.value)
assert "This guardrail always fails" in str(exc_info.value)
assert task.retry_count == 1
def test_multiple_guardrails_empty_list():
"""Test that empty guardrails list works correctly."""
agent = Mock()
agent.role = "empty_agent"
agent.execute_task.return_value = "no guardrails"
agent.crew = None
task = create_smart_task(
description="Test empty guardrails list",
expected_output="No processing",
guardrails=[],
)
result = task.execute_sync(agent=agent)
assert result.raw == "no guardrails"
def test_multiple_guardrails_with_llm_guardrails():
"""Test mixing callable and LLM guardrails."""
def callable_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Callable guardrail."""
return (True, f"Callable: {result.raw}")
# Create a proper mock agent without config issues
from crewai import Agent
agent = Agent(
role="mixed_guardrail_agent", goal="Test goal", backstory="Test backstory"
)
task = create_smart_task(
description="Test mixed guardrail types",
expected_output="Mixed processing",
guardrails=[callable_guardrail, "Ensure the output is professional"],
agent=agent,
)
# The LLM guardrail will be converted to LLMGuardrail internally
assert len(task._guardrails) == 2
assert callable(task._guardrails[0])
assert callable(task._guardrails[1]) # LLMGuardrail is callable
def test_multiple_guardrails_processing_order():
"""Test that guardrails are processed in the correct order."""
processing_order = []
def first_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("first")
return (True, f"1-{result.raw}")
def second_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("second")
return (True, f"2-{result.raw}")
def third_guardrail(result: TaskOutput) -> tuple[bool, str]:
processing_order.append("third")
return (True, f"3-{result.raw}")
agent = Mock()
agent.role = "order_agent"
agent.execute_task.return_value = "base"
agent.crew = None
task = create_smart_task(
description="Test processing order",
expected_output="Ordered processing",
guardrails=[first_guardrail, second_guardrail, third_guardrail],
)
result = task.execute_sync(agent=agent)
assert processing_order == ["first", "second", "third"]
assert result.raw == "3-2-1-base"
def test_multiple_guardrails_with_pydantic_output():
"""Test multiple guardrails with Pydantic output model."""
from pydantic import BaseModel, Field
class TestModel(BaseModel):
content: str = Field(description="The content")
processed: bool = Field(description="Whether it was processed")
def json_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Convert to JSON format."""
import json
data = {"content": result.raw, "processed": True}
return (True, json.dumps(data))
def validation_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Validate JSON structure."""
import json
try:
data = json.loads(result.raw)
if "content" not in data or "processed" not in data:
return (False, "Missing required fields")
return (True, result.raw)
except json.JSONDecodeError:
return (False, "Invalid JSON format")
agent = Mock()
agent.role = "pydantic_agent"
agent.execute_task.return_value = "test content"
agent.crew = None
task = create_smart_task(
description="Test guardrails with Pydantic",
expected_output="Structured output",
guardrails=[json_guardrail, validation_guardrail],
output_pydantic=TestModel,
)
result = task.execute_sync(agent=agent)
# Verify the result is valid JSON and can be parsed
import json
parsed = json.loads(result.raw)
assert parsed["content"] == "test content"
assert parsed["processed"] is True
def test_guardrails_vs_single_guardrail_mutual_exclusion():
"""Test that guardrails list nullifies single guardrail."""
def single_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""Single guardrail - should be ignored."""
return (True, f"Single: {result.raw}")
def list_guardrail(result: TaskOutput) -> tuple[bool, str]:
"""List guardrail - should be used."""
return (True, f"List: {result.raw}")
agent = Mock()
agent.role = "exclusion_agent"
agent.execute_task.return_value = "test"
agent.crew = None
task = create_smart_task(
description="Test mutual exclusion",
expected_output="Exclusion test",
guardrail=single_guardrail, # This should be ignored
guardrails=[list_guardrail], # This should be used
)
result = task.execute_sync(agent=agent)
# Should only use the guardrails list, not the single guardrail
assert result.raw == "List: test"
assert task._guardrail is None # Single guardrail should be nullified