Compare commits

..

5 Commits

Author SHA1 Message Date
Devin AI
f80fe7d4c1 fix: use unquoted type names in model descriptions
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 11:53:58 +00:00
Devin AI
da0d37af03 fix: ensure type names are quoted in model descriptions
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 11:50:37 +00:00
Devin AI
f65c31bfd0 style: fix import sorting
Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 11:46:37 +00:00
Devin AI
9322f06e7a refactor: address code review feedback
- Split describe_field into smaller functions
- Add error handling and logging
- Add comprehensive docstrings
- Add pytest marks for test organization
- Add edge case tests
- Add type hints and constants
- Add caching for performance

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 11:45:12 +00:00
Devin AI
326f406605 feat: enhance pydantic output to include field descriptions
- Update generate_model_description to include field descriptions
- Add tests for field description handling
- Maintain backward compatibility for fields without descriptions

Fixes #2188

Co-Authored-By: Joe Moura <joao@crewai.com>
2025-02-21 11:40:37 +00:00
6 changed files with 224 additions and 305 deletions

View File

@@ -3,20 +3,17 @@ import inspect
import json
import logging
import threading
import typing
import uuid
from concurrent.futures import Future
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import (
AbstractSet,
Any,
Callable,
ClassVar,
Dict,
List,
Mapping,
Optional,
Set,
Tuple,
@@ -35,7 +32,6 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
@@ -117,7 +113,7 @@ class Task(BaseModel):
description="Task output, it's final result after being executed", default=None
)
tools: Optional[List[BaseTool]] = Field(
default_factory=list[BaseTool],
default_factory=list,
description="Tools the agent is limited to use for this task.",
)
id: UUID4 = Field(
@@ -133,7 +129,7 @@ class Task(BaseModel):
description="A converter class used to export structured output",
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set[str])
processed_by_agents: Set[str] = Field(default_factory=set)
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task",
@@ -155,8 +151,8 @@ class Task(BaseModel):
"""Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one required positional parameter (the TaskOutput)
2. Checking return type annotations match tuple[bool, Any] or specific types like tuple[bool, str]
1. Verifying the function accepts exactly one parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present
3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because:
@@ -164,24 +160,6 @@ class Task(BaseModel):
- Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues
Examples:
Simple validation with new style annotation:
>>> def validate_output(result: TaskOutput) -> tuple[bool, str]:
... return (True, result.raw.upper())
Validation with optional parameters:
>>> def validate_with_options(result: TaskOutput, strict: bool = True) -> tuple[bool, str]:
... if strict and not result.raw.isupper():
... return (False, "Text must be uppercase")
... return (True, result.raw)
Validation with specific return type:
>>> def validate_task_output(result: TaskOutput) -> tuple[bool, TaskOutput]:
... if not result.raw:
... return (False, result)
... result.raw = result.raw.strip()
... return (True, result)
Args:
v: The guardrail function to validate
@@ -190,57 +168,22 @@ class Task(BaseModel):
Raises:
ValueError: If the function signature is invalid or return annotation
doesn't match tuple[bool, Any] or specific allowed types
doesn't match Tuple[bool, Any]
"""
if v is not None:
sig = inspect.signature(v)
# Get required positional parameters (excluding those with defaults)
required_params = [
param for param in sig.parameters.values()
if param.default == inspect.Parameter.empty
and param.kind in (inspect.Parameter.POSITIONAL_ONLY, inspect.Parameter.POSITIONAL_OR_KEYWORD)
]
keyword_only_params = [
param for param in sig.parameters.values()
if param.kind == inspect.Parameter.KEYWORD_ONLY
]
if len(required_params) != 1 or (len(keyword_only_params) > 0 and any(p.default == inspect.Parameter.empty for p in keyword_only_params)):
raise GuardrailValidationError(
"Guardrail function must accept exactly one required positional parameter and no required keyword-only parameters",
{"params": [str(p) for p in sig.parameters.values()]}
)
if len(sig.parameters) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it
type_hints = typing.get_type_hints(v)
return_annotation = type_hints.get('return')
if return_annotation:
# Convert annotation to string for comparison
annotation_str = str(return_annotation).lower().replace(' ', '')
# Normalize type strings
normalized_annotation = (
annotation_str.replace('typing.', '')
.replace('dict[str,typing.any]', 'dict[str,any]')
.replace('dict[str, any]', 'dict[str,any]')
)
VALID_RETURN_TYPES = {
'tuple[bool,any]',
'tuple[bool,str]',
'tuple[bool,dict[str,any]]',
'tuple[bool,taskoutput]'
}
# Check if the normalized annotation matches any valid pattern
is_valid = normalized_annotation == 'tuple[bool,any]'
if not is_valid:
is_valid = normalized_annotation in VALID_RETURN_TYPES
if not is_valid:
raise GuardrailValidationError(
f"Invalid return type annotation. Expected one of: "
f"{', '.join(VALID_RETURN_TYPES)}",
{"got": annotation_str}
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
if not (
return_annotation == Tuple[bool, Any]
or str(return_annotation) == "Tuple[bool, Any]"
):
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
)
return v
@@ -468,7 +411,6 @@ class Task(BaseModel):
"Task guardrail returned None as result. This is not allowed."
)
# Handle different result types
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
@@ -478,13 +420,6 @@ class Task(BaseModel):
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
elif isinstance(guardrail_result.result, dict):
task_output.raw = guardrail_result.result
task_output.json_dict = guardrail_result.result
pydantic_output, _ = self._export_output(
json.dumps(guardrail_result.result)
)
task_output.pydantic = pydantic_output
self.output = task_output
self.end_time = datetime.datetime.now()
@@ -675,74 +610,40 @@ class Task(BaseModel):
self.delegations += 1
def copy(
self,
agents: List["BaseAgent"] | None = None,
task_mapping: Dict[str, "Task"] | None = None,
*,
include: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
exclude: AbstractSet[int] | AbstractSet[str] | Mapping[int, Any] | Mapping[str, Any] | None = None,
update: dict[str, Any] | None = None,
deep: bool = False,
self, agents: List["BaseAgent"], task_mapping: Dict[str, "Task"]
) -> "Task":
"""Create a deep copy of the Task.
Args:
agents: Optional list of agents to copy agent references
task_mapping: Optional mapping of task keys to tasks for context
include: Fields to include in the copy
exclude: Fields to exclude from the copy
update: Fields to update in the copy
deep: Whether to perform a deep copy
"""
if agents is None and task_mapping is None:
# New style copy using BaseModel
copied = super().copy(
include=include,
exclude=exclude,
update=update,
deep=deep,
)
# Copy mutable fields
if self.tools:
copied.tools = copy(self.tools)
if self.context:
copied.context = copy(self.context)
return copied
# Legacy copy behavior
exclude_fields = {
"""Create a deep copy of the Task."""
exclude = {
"id",
"agent",
"context",
"tools",
}
copied_data = self.model_dump(exclude=exclude_fields)
copied_data = self.model_dump(exclude=exclude)
copied_data = {k: v for k, v in copied_data.items() if v is not None}
cloned_context = (
[task_mapping[context_task.key] for context_task in self.context]
if self.context and task_mapping
if self.context
else None
)
def get_agent_by_role(role: str) -> Union["BaseAgent", None]:
if not agents:
return None
return next((agent for agent in agents if agent.role == role), None)
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
cloned_tools = copy(self.tools) if self.tools else []
return Task(
copied_task = Task(
**copied_data,
context=cloned_context,
agent=cloned_agent,
tools=cloned_tools,
)
return copied_task
def _export_output(
self, result: str
) -> Tuple[Optional[BaseModel], Optional[Dict[str, Any]]]:
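After the changes in this file, the guardrail contract is back to its narrow form: the validator accepts any callable with exactly one parameter, and if a return annotation is present it must be Tuple[bool, Any]; a string result is written back to task_output.raw, and a TaskOutput result replaces the output wholesale. Below is a minimal sketch of a guardrail that conforms to those rules, assuming only what this diff shows; the JSON check inside is illustrative and not part of the change.

```python
import json
from typing import Any, Tuple

from crewai.task import Task
from crewai.tasks.task_output import TaskOutput


def validate_json_output(result: TaskOutput) -> Tuple[bool, Any]:
    # Exactly one parameter and a Tuple[bool, Any] annotation: the only
    # signature shape the simplified validator in this diff accepts.
    try:
        parsed = json.loads(result.raw)
    except json.JSONDecodeError:
        # (False, <error message>) feeds the retry path exercised by the
        # guardrail tests further down in this compare.
        return (False, "Expected JSON, got string")
    # A string result is assigned back to task_output.raw by the Task.
    return (True, json.dumps(parsed))


task = Task(
    description="Summarize the findings as JSON",
    expected_output="A JSON object",
    guardrail=validate_json_output,
    max_retries=1,
)
```

Per the restored validation, an unannotated guardrail is also accepted; only a return annotation that is present and differs from Tuple[bool, Any] raises a ValueError.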

View File

@@ -1,25 +0,0 @@
"""
Module for task-related exceptions.
This module provides custom exceptions used throughout the task system
to provide more specific error handling and context.
"""
from typing import Any, Dict, Optional
class GuardrailValidationError(Exception):
"""Exception raised for guardrail validation errors.
This exception provides detailed context about why a guardrail
validation failed, including the specific validation that failed
and any relevant context information.
Attributes:
message: A clear description of the validation error
context: Optional dictionary containing additional error context
"""
def __init__(self, message: str, context: Optional[Dict[str, Any]] = None):
self.message = message
self.context = context or {}
super().__init__(self.message)

View File

@@ -1,5 +1,5 @@
import json
from typing import Any, Dict, Optional, Union
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, model_validator
@@ -15,7 +15,7 @@ class TaskOutput(BaseModel):
description="Expected output of the task", default=None
)
summary: Optional[str] = Field(description="Summary of the task", default=None)
raw: Any = Field(description="Raw output of the task", default="")
raw: str = Field(description="Raw output of the task", default="")
pydantic: Optional[BaseModel] = Field(
description="Pydantic output of task", default=None
)

View File

@@ -1,5 +1,7 @@
import json
import logging
import re
from functools import lru_cache
from typing import Any, Optional, Type, Union, get_args, get_origin
from pydantic import BaseModel, ValidationError
@@ -8,6 +10,8 @@ from crewai.agents.agent_builder.utilities.base_output_converter import OutputCo
from crewai.utilities.printer import Printer
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
logger = logging.getLogger(__name__)
class ConverterError(Exception):
"""Error raised when Converter fails to parse the input."""
@@ -253,17 +257,57 @@ def create_converter(
return converter
FIELD_TYPE_KEY = "type"
FIELD_DESC_KEY = "description"
def generate_model_description(model: Type[BaseModel]) -> str:
"""
Generate a string description of a Pydantic model's fields and their types.
This function takes a Pydantic model class and returns a string that describes
the model's fields and their respective types. The description includes handling
of complex types such as `Optional`, `List`, and `Dict`, as well as nested Pydantic
models.
@lru_cache(maxsize=100)
def generate_model_description(model: Type[BaseModel]) -> str:
models and field descriptions when available.
Args:
model: A Pydantic BaseModel class to generate description for
Returns:
str: A JSON-like string describing the model's fields, their types, and descriptions
"""
def describe_field(field_type):
def describe_field(field_type: Any, field_info: Optional[Any] = None) -> Union[str, dict]:
"""
Generate a description for a model field including its type and description.
Args:
field_type: The type annotation of the field
field_info: Optional field information containing description
Returns:
Union[str, dict]: Field description either as string (type only) or
dict with type and description
"""
try:
type_desc = get_type_description(field_type)
if field_info and field_info.description:
return {FIELD_TYPE_KEY: type_desc, FIELD_DESC_KEY: field_info.description}
return type_desc
except Exception as e:
logger.warning(f"Error processing field description: {e}")
return str(field_type)
def get_type_description(field_type: Any) -> str:
"""
Get the type description for a field type.
Args:
field_type: The type annotation to describe
Returns:
str: A string representation of the type
"""
origin = get_origin(field_type)
args = get_args(field_type)
@@ -271,14 +315,14 @@ def generate_model_description(model: Type[BaseModel]) -> str:
# Handle both Union and the new '|' syntax
non_none_args = [arg for arg in args if arg is not type(None)]
if len(non_none_args) == 1:
return f"Optional[{describe_field(non_none_args[0])}]"
return f"Optional[{get_type_description(non_none_args[0])}]"
else:
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
return f"Optional[Union[{', '.join(get_type_description(arg) for arg in non_none_args)}]]"
elif origin is list:
return f"List[{describe_field(args[0])}]"
return f"List[{get_type_description(args[0])}]"
elif origin is dict:
key_type = describe_field(args[0])
value_type = describe_field(args[1])
key_type = get_type_description(args[0])
value_type = get_type_description(args[1])
return f"Dict[{key_type}, {value_type}]"
elif isinstance(field_type, type) and issubclass(field_type, BaseModel):
return generate_model_description(field_type)
@@ -287,8 +331,12 @@ def generate_model_description(model: Type[BaseModel]) -> str:
else:
return str(field_type)
fields = model.__annotations__
field_descriptions = [
f'"{name}": {describe_field(type_)}' for name, type_ in fields.items()
]
fields = model.model_fields
field_descriptions = []
for name, field in fields.items():
field_desc = describe_field(field.annotation, field)
if isinstance(field_desc, dict):
field_descriptions.append(f'"{name}": {json.dumps(field_desc)}')
else:
field_descriptions.append(f'"{name}": {field_desc}')
return "{\n " + ",\n ".join(field_descriptions) + "\n}"

View File

@@ -1,179 +1,129 @@
"""Tests for task guardrails functionality."""
from typing import Any, Dict
from unittest.mock import Mock
import pytest
from crewai.task import Task
from crewai.tasks.exceptions import GuardrailValidationError
from crewai.tasks.task_output import TaskOutput
class TestTaskGuardrails:
"""Test suite for task guardrail functionality."""
def test_task_without_guardrail():
"""Test that tasks work normally without guardrails (backward compatibility)."""
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
@pytest.fixture
def mock_agent(self):
"""Fixture providing a mock agent for testing."""
agent = Mock()
agent.role = "test_agent"
agent.crew = None
return agent
task = Task(description="Test task", expected_output="Output")
def test_task_without_guardrail(self, mock_agent):
"""Test that tasks work normally without guardrails (backward compatibility)."""
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output")
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test result"
def test_task_with_successful_guardrail(self, mock_agent):
"""Test that successful guardrail validation passes transformed result."""
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
def test_task_with_successful_guardrail():
"""Test that successful guardrail validation passes transformed result."""
mock_agent.execute_task.return_value = "test result"
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
def guardrail(result: TaskOutput):
return (True, result.raw.upper())
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "test result"
agent.crew = None
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
result = task.execute_sync(agent=agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_task_with_failing_guardrail(self, mock_agent):
"""Test that failing guardrail triggers retry with error context."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def test_task_with_failing_guardrail():
"""Test that failing guardrail triggers retry with error context."""
mock_agent.execute_task.side_effect = ["bad result", "good result"]
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
def guardrail(result: TaskOutput):
return (False, "Invalid format")
# First execution fails guardrail, second succeeds
mock_agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.side_effect = ["bad result", "good result"]
agent.crew = None
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
# First execution fails guardrail, second succeeds
agent.execute_task.side_effect = ["bad result", "good result"]
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert "Task failed guardrail validation" in str(exc_info.value)
assert task.retry_count == 1
def test_task_with_guardrail_retries(self, mock_agent):
"""Test that guardrail respects max_retries configuration."""
def guardrail(result: TaskOutput):
return (False, "Invalid format")
def test_task_with_guardrail_retries():
"""Test that guardrail respects max_retries configuration."""
mock_agent.execute_task.return_value = "bad result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
def guardrail(result: TaskOutput):
return (False, "Invalid format")
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
agent = Mock()
agent.role = "test_agent"
agent.execute_task.return_value = "bad result"
agent.crew = None
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=2,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
assert task.retry_count == 2
assert "Task failed guardrail validation after 2 retries" in str(exc_info.value)
assert "Invalid format" in str(exc_info.value)
def test_guardrail_error_in_context(self, mock_agent):
"""Test that guardrail error is passed in context for retry."""
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
def test_guardrail_error_in_context():
"""Test that guardrail error is passed in context for retry."""
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
def guardrail(result: TaskOutput):
return (False, "Expected JSON, got string")
# Mock execute_task to succeed on second attempt
first_call = True
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
agent = Mock()
agent.role = "test_agent"
agent.crew = None
mock_agent.execute_task.side_effect = execute_task
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail,
max_retries=1,
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=mock_agent)
# Mock execute_task to succeed on second attempt
first_call = True
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)
def execute_task(task, context, tools):
nonlocal first_call
if first_call:
first_call = False
return "invalid"
return '{"valid": "json"}'
agent.execute_task.side_effect = execute_task
def test_guardrail_with_new_style_annotation(self, mock_agent):
"""Test guardrail with new style tuple annotation."""
def guardrail(result: TaskOutput) -> tuple[bool, str]:
return (True, result.raw.upper())
mock_agent.execute_task.return_value = "test result"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
with pytest.raises(Exception) as exc_info:
task.execute_sync(agent=agent)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "TEST RESULT"
def test_guardrail_with_optional_params(self, mock_agent):
"""Test guardrail with optional parameters."""
def guardrail(result: TaskOutput, optional_param: str = "default") -> tuple[bool, str]:
return (True, f"{result.raw}-{optional_param}")
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == "test-default"
def test_guardrail_with_invalid_optional_params(self, mock_agent):
"""Test guardrail with invalid optional parameters."""
def guardrail(result: TaskOutput, *, required_kwonly: str) -> tuple[bool, str]:
return (True, result.raw)
with pytest.raises(GuardrailValidationError) as exc_info:
Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
assert "exactly one required positional parameter" in str(exc_info.value)
def test_guardrail_with_dict_return_type(self, mock_agent):
"""Test guardrail with dict return type."""
def guardrail(result: TaskOutput) -> tuple[bool, dict[str, Any]]:
return (True, {"processed": result.raw.upper()})
mock_agent.execute_task.return_value = "test"
task = Task(
description="Test task",
expected_output="Output",
guardrail=guardrail
)
result = task.execute_sync(agent=mock_agent)
assert isinstance(result, TaskOutput)
assert result.raw == {"processed": "TEST"}
assert "Task failed guardrail validation" in str(exc_info.value)
assert "Expected JSON, got string" in str(exc_info.value)

View File

@@ -4,7 +4,7 @@ from typing import Dict, List, Optional
from unittest.mock import MagicMock, Mock, patch
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field
from crewai.llm import LLM
from crewai.utilities.converter import (
@@ -328,6 +328,51 @@ def test_generate_model_description_dict_field():
assert description == expected_description
@pytest.mark.field_descriptions
def test_generate_model_description_with_field_descriptions():
"""
Verify that the model description generator correctly includes field descriptions
when they are provided via Field(..., description='...').
"""
class ModelWithDescriptions(BaseModel):
name: str = Field(..., description="The user's full name")
age: int = Field(..., description="The user's age in years")
description = generate_model_description(ModelWithDescriptions)
expected = '{\n "name": {"type": "str", "description": "The user\'s full name"},\n "age": {"type": "int", "description": "The user\'s age in years"}\n}'
assert description == expected
@pytest.mark.field_descriptions
def test_generate_model_description_mixed_fields():
"""
Verify that the model description generator correctly handles a mix of fields
with and without descriptions.
"""
class MixedModel(BaseModel):
name: str = Field(..., description="The user's name")
age: int # No description
description = generate_model_description(MixedModel)
expected = '{\n "name": {"type": "str", "description": "The user\'s name"},\n "age": int\n}'
assert description == expected
@pytest.mark.field_descriptions
def test_generate_model_description_with_empty_description():
"""
Verify that the model description generator correctly handles fields with empty
descriptions by treating them as fields without descriptions.
"""
class ModelWithEmptyDescription(BaseModel):
name: str = Field(..., description="")
age: int = Field(..., description=None)
description = generate_model_description(ModelWithEmptyDescription)
expected = '{\n "name": str,\n "age": int\n}'
assert description == expected
@pytest.mark.vcr(filter_headers=["authorization"])
def test_convert_with_instructions():
llm = LLM(model="gpt-4o-mini")