Merge branch 'main' into joaomdmoura/multimodal-crew

This commit is contained in:
João Moura
2024-12-26 23:33:37 -03:00
committed by GitHub
17 changed files with 2015 additions and 179 deletions

View File

@@ -80,10 +80,27 @@ def listen(condition):
return decorator
def router(method):
def router(condition):
def decorator(func):
func.__is_router__ = True
func.__router_for__ = method.__name__
# Handle conditions like listen/start
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
else:
raise ValueError(
"Condition must be a method, string, or a result of or_() or and_()"
)
return func
return decorator
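
With this change, `router` takes a condition directly (a method, a string, or the result of `or_()`/`and_()`), matching the `listen`/`start` contract. A minimal usage sketch; the flow class and method names below are illustrative, not part of this diff:

    from crewai.flow.flow import Flow, listen, or_, router, start

    class ExampleFlow(Flow):
        @start()
        def begin(self):
            return "begin"

        # The condition may be a method reference, as with listen():
        @router(begin)
        def gate(self):
            # Explicit constant returns let get_possible_return_constants()
            # discover this router's paths.
            if self.state.get("ok"):
                return "success"
            return "failure"

        # ...or a combined condition built with or_()/and_():
        @listen(or_("success", "failure"))
        def wrap_up(self, path):
            return path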
@@ -123,8 +140,8 @@ class FlowMeta(type):
start_methods = []
listeners = {}
routers = {}
router_paths = {}
routers = set()
for attr_name, attr_value in dct.items():
if hasattr(attr_value, "__is_start_method__"):
@@ -137,18 +154,11 @@ class FlowMeta(type):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
listeners[attr_name] = (condition_type, methods)
elif hasattr(attr_value, "__is_router__"):
routers[attr_value.__router_for__] = attr_name
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns
# Register router as a listener to its triggering method
trigger_method_name = attr_value.__router_for__
methods = [trigger_method_name]
condition_type = "OR"
listeners[attr_name] = (condition_type, methods)
if hasattr(attr_value, "__is_router__") and attr_value.__is_router__:
routers.add(attr_name)
possible_returns = get_possible_return_constants(attr_value)
if possible_returns:
router_paths[attr_name] = possible_returns
setattr(cls, "_start_methods", start_methods)
setattr(cls, "_listeners", listeners)
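
Because the decorator now sets `__trigger_methods__` itself, a router is also registered as an ordinary listener of its condition. For a class like the `ExampleFlow` sketch above, the metaclass would record roughly the following (a hedged illustration; exact ordering may differ):

    ExampleFlow._start_methods  # ['begin']
    ExampleFlow._listeners      # {'gate': ('OR', ['begin']), 'wrap_up': ('OR', ['success', 'failure'])}
    ExampleFlow._routers        # {'gate'} — now a set of method names, not a dict
    ExampleFlow._router_paths   # {'gate': ['success', 'failure']}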
@@ -163,7 +173,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
_start_methods: List[str] = []
_listeners: Dict[str, tuple[str, List[str]]] = {}
_routers: Dict[str, str] = {}
_routers: Set[str] = set()
_router_paths: Dict[str, List[str]] = {}
initial_state: Union[Type[T], T, None] = None
event_emitter = Signal("event_emitter")
@@ -210,20 +220,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
return self._method_outputs
def _initialize_state(self, inputs: Dict[str, Any]) -> None:
"""
Initializes or updates the state with the provided inputs.
Args:
inputs: Dictionary of inputs to initialize or update the state.
Raises:
ValueError: If inputs do not match the structured state model.
TypeError: If state is neither a BaseModel instance nor a dictionary.
"""
if isinstance(self._state, BaseModel):
# Structured state management
# Structured state
try:
# Define a function to create the dynamic class
def create_model_with_extra_forbid(
base_model: Type[BaseModel],
) -> Type[BaseModel]:
@@ -233,34 +233,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
return ModelWithExtraForbid
# Create the dynamic class
ModelWithExtraForbid = create_model_with_extra_forbid(
self._state.__class__
)
# Create a new instance using the combined state and inputs
self._state = cast(
T, ModelWithExtraForbid(**{**self._state.model_dump(), **inputs})
)
except ValidationError as e:
raise ValueError(f"Invalid inputs for structured state: {e}") from e
elif isinstance(self._state, dict):
# Unstructured state management
self._state.update(inputs)
else:
raise TypeError("State must be a BaseModel instance or a dictionary.")
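
For reference, a sketch of the structured state style this method validates; the model and field names are illustrative:

    from pydantic import BaseModel
    from crewai.flow.flow import Flow, start

    class CounterState(BaseModel):
        counter: int = 0

    class StructuredFlow(Flow[CounterState]):
        initial_state = CounterState

        @start()
        def begin(self):
            return self.state.counter

    StructuredFlow().kickoff(inputs={"counter": 3})  # validated against the model
    # inputs={"unknown": 1} would raise ValueError: the dynamically built
    # ModelWithExtraForbid sets extra="forbid" on the state's class.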
def kickoff(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow synchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
self.event_emitter.send(
self,
event=FlowStartedEvent(
@@ -274,15 +260,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
return asyncio.run(self.kickoff_async())
async def kickoff_async(self, inputs: Optional[Dict[str, Any]] = None) -> Any:
"""
Starts the execution of the flow asynchronously.
Args:
inputs: Optional dictionary of inputs to initialize or update the state.
Returns:
The final output from the flow execution.
"""
if not self._start_methods:
raise ValueError("No start method defined")
@@ -290,16 +267,12 @@ class Flow(Generic[T], metaclass=FlowMeta):
self.__class__.__name__, list(self._methods.keys())
)
# Create tasks for all start methods
tasks = [
self._execute_start_method(start_method)
for start_method in self._start_methods
]
# Run all start methods concurrently
await asyncio.gather(*tasks)
# Determine the final output (from the last executed method)
final_output = self._method_outputs[-1] if self._method_outputs else None
self.event_emitter.send(
@@ -310,7 +283,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
result=final_output,
),
)
return final_output
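
All `@start()` methods are gathered and awaited concurrently; the flow's return value is simply the last recorded method output. A sketch, with illustrative method names:

    import asyncio
    from crewai.flow.flow import Flow, start

    class FanOutFlow(Flow):
        @start()
        async def fetch_a(self):
            await asyncio.sleep(0.1)
            return "a"

        @start()
        async def fetch_b(self):
            await asyncio.sleep(0.1)
            return "b"

    result = FanOutFlow().kickoff()  # "a" or "b": whichever method finished last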
async def _execute_start_method(self, start_method_name: str) -> None:
@@ -327,49 +299,68 @@ class Flow(Generic[T], metaclass=FlowMeta):
if asyncio.iscoroutinefunction(method)
else method(*args, **kwargs)
)
self._method_outputs.append(result) # Store the output
# Track method execution counts
self._method_outputs.append(result)
self._method_execution_counts[method_name] = (
self._method_execution_counts.get(method_name, 0) + 1
)
return result
async def _execute_listeners(self, trigger_method: str, result: Any) -> None:
listener_tasks = []
if trigger_method in self._routers:
router_method = self._methods[self._routers[trigger_method]]
path = await self._execute_method(
self._routers[trigger_method], router_method
# First, handle routers repeatedly until no router triggers anymore
while True:
routers_triggered = self._find_triggered_methods(
trigger_method, router_only=True
)
trigger_method = path
if not routers_triggered:
break
for router_name in routers_triggered:
await self._execute_single_listener(router_name, result)
            # The routers' return values are the paths; the last router executed
            # (i.e., the last element of self._method_outputs) becomes the new
            # trigger_method for the next iteration.
            trigger_method = self._method_outputs[-1]
# Now that no more routers are triggered by current trigger_method,
# execute normal listeners
listeners_triggered = self._find_triggered_methods(
trigger_method, router_only=False
)
if listeners_triggered:
tasks = [
self._execute_single_listener(listener_name, result)
for listener_name in listeners_triggered
]
await asyncio.gather(*tasks)
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> List[str]:
triggered = []
for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
continue
if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
# Schedule the listener without preventing re-execution
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
self._pending_and_listeners[listener_name].discard(trigger_method)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
listener_tasks.append(
self._execute_single_listener(listener_name, result)
)
triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
# Run all listener tasks concurrently and wait for them to complete
if listener_tasks:
await asyncio.gather(*listener_tasks)
return triggered
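
The `while` loop above means routers can chain: each router's returned path may trigger another router before any ordinary listener runs. A sketch of a two-hop chain (class and method names illustrative):

    from crewai.flow.flow import Flow, listen, router, start

    class ChainFlow(Flow):
        @start()
        def begin(self):
            return "begin"

        @router(begin)
        def first_gate(self):
            return "checked"

        @router("checked")      # fires on the first router's path
        def second_gate(self):
            return "approved"

        @listen("approved")     # ordinary listeners run once no router fires
        def finish(self):
            return "done"

    ChainFlow().kickoff()  # "done"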
async def _execute_single_listener(self, listener_name: str, result: Any) -> None:
try:
@@ -386,17 +377,13 @@ class Flow(Generic[T], metaclass=FlowMeta):
sig = inspect.signature(method)
params = list(sig.parameters.values())
# Exclude 'self' parameter
method_params = [p for p in params if p.name != "self"]
if method_params:
# If listener expects parameters, pass the result
listener_result = await self._execute_method(
listener_name, method, result
)
else:
# If listener does not expect parameters, call without arguments
listener_result = await self._execute_method(listener_name, method)
self.event_emitter.send(
@@ -408,8 +395,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
),
)
# Execute listeners of this listener
# Execute listeners (and possibly routers) of this listener
await self._execute_listeners(listener_name, listener_result)
except Exception as e:
print(
f"[Flow._execute_single_listener] Error in method {listener_name}: {e}"
@@ -422,5 +410,4 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._telemetry.flow_plotting_span(
self.__class__.__name__, list(self._methods.keys())
)
plot_flow(self, filename)

View File

@@ -31,16 +31,50 @@ def get_possible_return_constants(function):
print(f"Source code:\n{source}")
return None
return_values = []
return_values = set()
dict_definitions = {}
class DictionaryAssignmentVisitor(ast.NodeVisitor):
def visit_Assign(self, node):
# Check if this assignment is assigning a dictionary literal to a variable
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
target = node.targets[0]
if isinstance(target, ast.Name):
var_name = target.id
dict_values = []
# Extract string values from the dictionary
for val in node.value.values:
if isinstance(val, ast.Constant) and isinstance(val.value, str):
dict_values.append(val.value)
                        # Non-string dictionary values are simply ignored
if dict_values:
dict_definitions[var_name] = dict_values
self.generic_visit(node)
class ReturnVisitor(ast.NodeVisitor):
def visit_Return(self, node):
# Check if the return value is a constant (Python 3.8+)
if isinstance(node.value, ast.Constant):
return_values.append(node.value.value)
# Direct string return
if isinstance(node.value, ast.Constant) and isinstance(
node.value.value, str
):
return_values.add(node.value.value)
# Dictionary-based return, like return paths[result]
elif isinstance(node.value, ast.Subscript):
# Check if we're subscripting a known dictionary variable
if isinstance(node.value.value, ast.Name):
var_name = node.value.value.id
if var_name in dict_definitions:
# Add all possible dictionary values
for v in dict_definitions[var_name]:
return_values.add(v)
self.generic_visit(node)
# First pass: identify dictionary assignments
DictionaryAssignmentVisitor().visit(code_ast)
# Second pass: identify returns
ReturnVisitor().visit(code_ast)
return return_values
return list(return_values) if return_values else None
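
A sketch of what the two-pass visitor now detects; `route_by_score` is an illustrative function, while `get_possible_return_constants` is the utility above:

    def route_by_score(self):
        paths = {"high": "fast_track", "low": "manual_review"}
        key = "high" if self.state["score"] > 0.8 else "low"
        return paths[key]

    get_possible_return_constants(route_by_score)
    # ['fast_track', 'manual_review'] — dict values reached via subscript,
    # in no guaranteed order (the collector is now a set)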
def calculate_node_levels(flow):
@@ -61,10 +95,7 @@ def calculate_node_levels(flow):
current_level = levels[current]
visited.add(current)
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
if current in trigger_methods:
if (
@@ -89,7 +120,7 @@ def calculate_node_levels(flow):
queue.append(listener_name)
# Handle router connections
if current in flow._routers.values():
if current in flow._routers:
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
@@ -105,6 +136,7 @@ def calculate_node_levels(flow):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels
@@ -142,7 +174,7 @@ def dfs_ancestors(node, ancestors, visited, flow):
dfs_ancestors(listener_name, ancestors, visited, flow)
# Handle router methods separately
if node in flow._routers.values():
if node in flow._routers:
router_method_name = node
paths = flow._router_paths.get(router_method_name, [])
for path in paths:

View File

@@ -94,12 +94,14 @@ def add_edges(net, flow, node_positions, colors):
ancestors = build_ancestor_dict(flow)
parent_children = build_parent_children_dict(flow)
# Edges for normal listeners
for method_name in flow._listeners:
condition_type, trigger_methods = flow._listeners[method_name]
is_and_condition = condition_type == "AND"
for trigger in trigger_methods:
if trigger in flow._methods or trigger in flow._routers.values():
# Check if nodes exist before adding edges
if trigger in node_positions and method_name in node_positions:
is_router_edge = any(
trigger in paths for paths in flow._router_paths.values()
)
@@ -135,7 +137,22 @@ def add_edges(net, flow, node_positions, colors):
}
net.add_edge(trigger, method_name, **edge_style)
else:
# Nodes not found in node_positions. Check if it's a known router outcome and a known method.
is_router_edge = any(
trigger in paths for paths in flow._router_paths.values()
)
# Check if method_name is a known method
method_known = method_name in flow._methods
# If it's a known router edge and the method is known, don't warn.
# This means the path is legitimate, just not reflected as nodes here.
if not (is_router_edge and method_known):
print(
f"Warning: No node found for '{trigger}' or '{method_name}'. Skipping edge."
)
# Edges for router return paths
for router_method_name, paths in flow._router_paths.items():
for path in paths:
for listener_name, (
@@ -143,36 +160,49 @@ def add_edges(net, flow, node_positions, colors):
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
is_cycle_edge = is_ancestor(trigger, method_name, ancestors)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if (
router_method_name in node_positions
and listener_name in node_positions
):
is_cycle_edge = is_ancestor(
router_method_name, listener_name, ancestors
)
parent_has_multiple_children = (
len(parent_children.get(router_method_name, [])) > 1
)
needs_curvature = is_cycle_edge or parent_has_multiple_children
if needs_curvature:
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if needs_curvature:
source_pos = node_positions.get(router_method_name)
target_pos = node_positions.get(listener_name)
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
if source_pos and target_pos:
dx = target_pos[0] - source_pos[0]
smooth_type = "curvedCCW" if dx <= 0 else "curvedCW"
index = get_child_index(
router_method_name, listener_name, parent_children
)
edge_smooth = {
"type": smooth_type,
"roundness": 0.2 + (0.1 * index),
}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = {"type": "cubicBezier"}
else:
edge_smooth = False
edge_smooth = False
edge_style = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
edge_style = {
"color": colors["router_edge"],
"width": 2,
"arrows": "to",
"dashes": True,
"smooth": edge_smooth,
}
net.add_edge(router_method_name, listener_name, **edge_style)
else:
# Same check here: known router edge and known method?
method_known = listener_name in flow._methods
if not method_known:
print(
f"Warning: No node found for '{router_method_name}' or '{listener_name}'. Skipping edge."
)

View File

@@ -1,8 +1,8 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Dict, List, Union
from typing import Dict, List, Optional, Union
from pydantic import Field
from pydantic import Field, field_validator
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.storage.knowledge_storage import KnowledgeStorage
@@ -14,17 +14,28 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Base class for knowledge sources that load content from files."""
_logger: Logger = Logger(verbose=True)
file_path: Union[Path, List[Path], str, List[str]] = Field(
..., description="The path to the file"
file_path: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default=None,
description="[Deprecated] The path to the file. Use file_paths instead.",
)
file_paths: Optional[Union[Path, List[Path], str, List[str]]] = Field(
default_factory=list, description="The path to the file"
)
content: Dict[Path, str] = Field(init=False, default_factory=dict)
storage: KnowledgeStorage = Field(default_factory=KnowledgeStorage)
safe_file_paths: List[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, values):
"""Validate that at least one of file_path or file_paths is provided."""
if v is None and ("file_path" not in values or values.get("file_path") is None):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _):
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_paths()
self.validate_content()
self.content = self.load_content()
@abstractmethod
@@ -32,7 +43,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
pass
def validate_paths(self):
def validate_content(self):
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -59,13 +70,30 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def _process_file_paths(self) -> List[Path]:
"""Convert file_path to a list of Path objects."""
paths = (
[self.file_path]
if isinstance(self.file_path, (str, Path))
else self.file_path
if hasattr(self, "file_path") and self.file_path is not None:
self._logger.log(
"warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
self.file_paths = self.file_path
if self.file_paths is None:
raise ValueError("Your source must be provided with a file_paths: []")
# Convert single path to list
path_list: List[Union[Path, str]] = (
[self.file_paths]
if isinstance(self.file_paths, (str, Path))
else list(self.file_paths)
if isinstance(self.file_paths, list)
else []
)
if not isinstance(paths, list):
raise ValueError("file_path must be a Path, str, or a list of these types")
if not path_list:
raise ValueError(
"file_path/file_paths must be a Path, str, or a list of these types"
)
return [self.convert_to_path(path) for path in paths]
return [self.convert_to_path(path) for path in path_list]
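
For concrete subclasses such as `TextFileKnowledgeSource`, migration looks like this (a sketch; `notes.txt` is assumed to exist under the knowledge directory):

    from crewai.knowledge.source.text_file_knowledge_source import TextFileKnowledgeSource

    # Still works, but now logs a deprecation warning:
    legacy = TextFileKnowledgeSource(file_path="notes.txt")

    # Preferred going forward:
    source = TextFileKnowledgeSource(file_paths=["notes.txt"])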

View File

@@ -21,7 +21,7 @@ class BaseKnowledgeSource(BaseModel, ABC):
collection_name: Optional[str] = Field(default=None)
@abstractmethod
def load_content(self) -> Dict[Any, str]:
def validate_content(self) -> Any:
"""Load and preprocess content from the source."""
pass

View File

@@ -0,0 +1,120 @@
from pathlib import Path
from typing import Iterator, List, Optional, Union
from urllib.parse import urlparse
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.utilities.constants import KNOWLEDGE_DIRECTORY
from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""
_logger: Logger = Logger(verbose=True)
file_path: Optional[List[Union[Path, str]]] = Field(default=None)
file_paths: List[Union[Path, str]] = Field(default_factory=list)
chunks: List[str] = Field(default_factory=list)
safe_file_paths: List[Union[Path, str]] = Field(default_factory=list)
content: List[DoclingDocument] = Field(default_factory=list)
document_converter: DocumentConverter = Field(
default_factory=lambda: DocumentConverter(
allowed_formats=[
InputFormat.MD,
InputFormat.ASCIIDOC,
InputFormat.PDF,
InputFormat.DOCX,
InputFormat.HTML,
InputFormat.IMAGE,
InputFormat.XLSX,
InputFormat.PPTX,
]
)
)
def model_post_init(self, _) -> None:
if self.file_path:
self._logger.log(
"warning",
"The 'file_path' attribute is deprecated and will be removed in a future version. Please use 'file_paths' instead.",
color="yellow",
)
self.file_paths = self.file_path
self.safe_file_paths = self.validate_content()
self.content = self._load_content()
def _load_content(self) -> List[DoclingDocument]:
try:
return self._convert_source_to_docling_documents()
except ConversionError as e:
self._logger.log(
"error",
f"Error loading content: {e}. Supported formats: {self.document_converter.allowed_formats}",
"red",
)
raise e
except Exception as e:
self._logger.log("error", f"Error loading content: {e}")
raise e
def add(self) -> None:
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
def _convert_source_to_docling_documents(self) -> List[DoclingDocument]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]
def _chunk_doc(self, doc: DoclingDocument) -> Iterator[str]:
chunker = HierarchicalChunker()
for chunk in chunker.chunk(doc):
yield chunk.text
def validate_content(self) -> List[Union[Path, str]]:
processed_paths: List[Union[Path, str]] = []
for path in self.file_paths:
if isinstance(path, str):
if path.startswith(("http://", "https://")):
try:
if self._validate_url(path):
processed_paths.append(path)
else:
raise ValueError(f"Invalid URL format: {path}")
except Exception as e:
raise ValueError(f"Invalid URL: {path}. Error: {str(e)}")
else:
local_path = Path(KNOWLEDGE_DIRECTORY + "/" + path)
if local_path.exists():
processed_paths.append(local_path)
else:
raise FileNotFoundError(f"File not found: {local_path}")
else:
# this is an instance of Path
processed_paths.append(path)
return processed_paths
def _validate_url(self, url: str) -> bool:
try:
result = urlparse(url)
return all(
[
result.scheme in ("http", "https"),
result.netloc,
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
]
)
except Exception:
return False
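
Putting the new source together (the file name and URL are assumptions; local paths resolve against the knowledge directory, and URLs must pass `_validate_url`):

    from crewai.knowledge.source.crew_docling_source import CrewDoclingSource

    source = CrewDoclingSource(
        file_paths=[
            "report.pdf",                           # resolved under knowledge/
            "https://example.com/whitepaper.html",  # scheme + dotted host required
        ]
    )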

View File

@@ -13,9 +13,9 @@ class StringKnowledgeSource(BaseKnowledgeSource):
def model_post_init(self, _):
"""Post-initialization method to validate content."""
self.load_content()
self.validate_content()
def load_content(self):
def validate_content(self):
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")

View File

@@ -1,12 +1,25 @@
import datetime
import inspect
import json
import logging
import threading
import uuid
from concurrent.futures import Future
from copy import copy
from hashlib import md5
from pathlib import Path
from typing import Any, Dict, List, Optional, Set, Tuple, Type, Union
from typing import (
Any,
Callable,
ClassVar,
Dict,
List,
Optional,
Set,
Tuple,
Type,
Union,
)
from opentelemetry.trace import Span
from pydantic import (
@@ -20,6 +33,7 @@ from pydantic import (
from pydantic_core import PydanticCustomError
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.tasks.guardrail_result import GuardrailResult
from crewai.tasks.output_format import OutputFormat
from crewai.tasks.task_output import TaskOutput
from crewai.telemetry.telemetry import Telemetry
@@ -49,6 +63,7 @@ class Task(BaseModel):
"""
__hash__ = object.__hash__ # type: ignore
logger: ClassVar[logging.Logger] = logging.getLogger(__name__)
used_tools: int = 0
tools_errors: int = 0
delegations: int = 0
@@ -110,6 +125,55 @@ class Task(BaseModel):
default=None,
)
processed_by_agents: Set[str] = Field(default_factory=set)
guardrail: Optional[Callable[[TaskOutput], Tuple[bool, Any]]] = Field(
default=None,
description="Function to validate task output before proceeding to next task"
)
max_retries: int = Field(
default=3,
description="Maximum number of retries when guardrail fails"
)
retry_count: int = Field(
default=0,
description="Current number of retries"
)
@field_validator("guardrail")
@classmethod
def validate_guardrail_function(cls, v: Optional[Callable]) -> Optional[Callable]:
"""Validate that the guardrail function has the correct signature and behavior.
While type hints provide static checking, this validator ensures runtime safety by:
1. Verifying the function accepts exactly one parameter (the TaskOutput)
2. Checking return type annotations match Tuple[bool, Any] if present
3. Providing clear, immediate error messages for debugging
This runtime validation is crucial because:
- Type hints are optional and can be ignored at runtime
- Function signatures need immediate validation before task execution
- Clear error messages help users debug guardrail implementation issues
Args:
v: The guardrail function to validate
Returns:
The validated guardrail function
Raises:
ValueError: If the function signature is invalid or return annotation
doesn't match Tuple[bool, Any]
"""
if v is not None:
sig = inspect.signature(v)
if len(sig.parameters) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
if not (return_annotation == Tuple[bool, Any] or str(return_annotation) == 'Tuple[bool, Any]'):
raise ValueError("If return type is annotated, it must be Tuple[bool, Any]")
return v
_telemetry: Telemetry = PrivateAttr(default_factory=Telemetry)
_execution_span: Optional[Span] = PrivateAttr(default=None)
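
A guardrail that satisfies this validator accepts exactly one `TaskOutput` and returns `Tuple[bool, Any]`. A sketch; `validate_json_output` and the task fields are illustrative:

    import json
    from typing import Any, Tuple

    from crewai import Task
    from crewai.tasks.task_output import TaskOutput

    def validate_json_output(output: TaskOutput) -> Tuple[bool, Any]:
        """Pass the output through only if it parses as JSON."""
        try:
            json.loads(output.raw)
            return True, output.raw
        except json.JSONDecodeError:
            return False, "Output must be valid JSON"

    task = Task(
        description="Return the user profile as a JSON object",
        expected_output="A JSON object",
        guardrail=validate_json_output,
        max_retries=2,
    )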
@@ -254,7 +318,6 @@ class Task(BaseModel):
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name,
description=self.description,
@@ -265,6 +328,37 @@ class Task(BaseModel):
agent=agent.role,
output_format=self._get_output_format(),
)
if self.guardrail:
guardrail_result = GuardrailResult.from_tuple(self.guardrail(task_output))
if not guardrail_result.success:
if self.retry_count >= self.max_retries:
raise Exception(
f"Task failed guardrail validation after {self.max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
self.retry_count += 1
context = (
f"### Previous attempt failed validation: {guardrail_result.error}\n\n\n"
f"### Previous result:\n{task_output.raw}\n\n\n"
"Try again, making sure to address the validation error."
)
return self._execute_core(agent, context, tools)
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(guardrail_result.result)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
self.output = task_output
self._set_end_execution_time(start_time)
@@ -308,7 +402,18 @@ class Task(BaseModel):
if inputs:
self.description = self._original_description.format(**inputs)
self.expected_output = self._original_expected_output.format(**inputs)
self.expected_output = self.interpolate_only(
input_string=self._original_expected_output, inputs=inputs
)
def interpolate_only(self, input_string: str, inputs: Dict[str, Any]) -> str:
"""Interpolate placeholders (e.g., {key}) in a string while leaving JSON untouched."""
escaped_string = input_string.replace("{", "{{").replace("}", "}}")
for key in inputs.keys():
escaped_string = escaped_string.replace(f"{{{{{key}}}}}", f"{{{key}}}")
return escaped_string.format(**inputs)
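
This keeps literal braces in the expected output (JSON schemas, for example) from being swallowed by `str.format`. A quick illustration, given a `Task` instance `task` (the template is made up):

    template = 'Summarize {topic}; answer as {"topic": "...", "summary": "..."}'
    task.interpolate_only(input_string=template, inputs={"topic": "routers"})
    # -> 'Summarize routers; answer as {"topic": "...", "summary": "..."}'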
def increment_tools_errors(self) -> None:
"""Increment the tools errors counter."""
@@ -390,22 +495,33 @@ class Task(BaseModel):
return OutputFormat.RAW
def _save_file(self, result: Any) -> None:
"""Save task output to a file.
Args:
result: The result to save to the file. Can be a dict or any stringifiable object.
Raises:
ValueError: If output_file is not set
RuntimeError: If there is an error writing to the file
"""
if self.output_file is None:
raise ValueError("output_file is not set.")
resolved_path = Path(self.output_file).expanduser().resolve()
directory = resolved_path.parent
try:
resolved_path = Path(self.output_file).expanduser().resolve()
directory = resolved_path.parent
if not directory.exists():
directory.mkdir(parents=True, exist_ok=True)
if not directory.exists():
directory.mkdir(parents=True, exist_ok=True)
with resolved_path.open("w", encoding="utf-8") as file:
if isinstance(result, dict):
import json
json.dump(result, file, ensure_ascii=False, indent=2)
else:
file.write(str(result))
with resolved_path.open("w", encoding="utf-8") as file:
if isinstance(result, dict):
import json
json.dump(result, file, ensure_ascii=False, indent=2)
else:
file.write(str(result))
except (OSError, IOError) as e:
raise RuntimeError(f"Failed to save output file: {e}")
return None
def __repr__(self):

View File

@@ -0,0 +1,56 @@
"""
Module for handling task guardrail validation results.
This module provides the GuardrailResult class which standardizes
the way task guardrails return their validation results.
"""
from typing import Any, Optional, Tuple, Union
from pydantic import BaseModel, field_validator
class GuardrailResult(BaseModel):
"""Result from a task guardrail execution.
This class standardizes the return format of task guardrails,
converting tuple responses into a structured format that can
be easily handled by the task execution system.
Attributes:
success (bool): Whether the guardrail validation passed
result (Any, optional): The validated/transformed result if successful
error (str, optional): Error message if validation failed
"""
success: bool
result: Optional[Any] = None
error: Optional[str] = None
@field_validator("result", "error")
@classmethod
def validate_result_error_exclusivity(cls, v: Any, info) -> Any:
values = info.data
if "success" in values:
if values["success"] and v and "error" in values and values["error"]:
raise ValueError("Cannot have both result and error when success is True")
if not values["success"] and v and "result" in values and values["result"]:
raise ValueError("Cannot have both result and error when success is False")
return v
@classmethod
def from_tuple(cls, result: Tuple[bool, Union[Any, str]]) -> "GuardrailResult":
"""Create a GuardrailResult from a validation tuple.
Args:
result: A tuple of (success, data) where data is either
the validated result or error message.
Returns:
GuardrailResult: A new instance with the tuple data.
"""
success, data = result
return cls(
success=success,
result=data if success else None,
error=data if not success else None
)
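
Illustrative round-trips through `from_tuple` (the values are made up):

    ok = GuardrailResult.from_tuple((True, "validated text"))
    assert ok.success and ok.result == "validated text" and ok.error is None

    bad = GuardrailResult.from_tuple((False, "missing required field"))
    assert not bad.success and bad.result is None and bad.error == "missing required field"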