mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00
Compare commits
2 Commits
devin/1760
...
devin/1760
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
5ccdfdb1d3 | ||
|
|
01114168df |
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "0.203.1"
|
||||
__version__ = "0.203.0"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
@@ -30,7 +30,6 @@ def validate_jwt_token(
|
||||
algorithms=["RS256"],
|
||||
audience=audience,
|
||||
issuer=issuer,
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
|
||||
@@ -1,11 +1,10 @@
|
||||
import os
|
||||
import subprocess
|
||||
from enum import Enum
|
||||
|
||||
import click
|
||||
from packaging import version
|
||||
|
||||
from crewai.cli.utils import build_env_with_tool_repository_credentials, read_toml
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
|
||||
|
||||
@@ -56,22 +55,8 @@ def execute_command(crew_type: CrewType) -> None:
|
||||
"""
|
||||
command = ["uv", "run", "kickoff" if crew_type == CrewType.FLOW else "run_crew"]
|
||||
|
||||
env = os.environ.copy()
|
||||
try:
|
||||
pyproject_data = read_toml()
|
||||
sources = pyproject_data.get("tool", {}).get("uv", {}).get("sources", {})
|
||||
|
||||
for source_config in sources.values():
|
||||
if isinstance(source_config, dict):
|
||||
index = source_config.get("index")
|
||||
if index:
|
||||
index_env = build_env_with_tool_repository_credentials(index)
|
||||
env.update(index_env)
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
try:
|
||||
subprocess.run(command, capture_output=False, text=True, check=True, env=env) # noqa: S603
|
||||
subprocess.run(command, capture_output=False, text=True, check=True) # noqa: S603
|
||||
|
||||
except subprocess.CalledProcessError as e:
|
||||
handle_error(e, crew_type)
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.1,<1.0.0"
|
||||
"crewai[tools]>=0.203.0,<1.0.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.1,<1.0.0",
|
||||
"crewai[tools]>=0.203.0,<1.0.0",
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
|
||||
readme = "README.md"
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]>=0.203.1"
|
||||
"crewai[tools]>=0.203.0"
|
||||
]
|
||||
|
||||
[tool.crewai]
|
||||
|
||||
@@ -358,8 +358,7 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
try:
|
||||
response = input().strip().lower()
|
||||
result[0] = response in ["y", "yes"]
|
||||
except (EOFError, KeyboardInterrupt, OSError, LookupError):
|
||||
# Handle all input-related errors silently
|
||||
except (EOFError, KeyboardInterrupt):
|
||||
result[0] = False
|
||||
|
||||
input_thread = threading.Thread(target=get_input, daemon=True)
|
||||
@@ -372,7 +371,6 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Suppress any warnings or errors and assume "no"
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
|
||||
from crewai.flow.persistence.base import FlowPersistence
|
||||
from crewai.flow.types import FlowExecutionData
|
||||
from crewai.flow.utils import get_possible_return_constants
|
||||
from crewai.utilities.printer import Printer, PrinterColor
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
condition : Optional[Union[str, dict, Callable]], optional
|
||||
Defines when the start method should execute. Can be:
|
||||
- str: Name of a method that triggers this start
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this start
|
||||
Default is None, meaning unconditional start.
|
||||
|
||||
@@ -140,18 +140,13 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -177,7 +172,7 @@ def listen(condition: str | dict | Callable) -> Callable:
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the listener should execute. Can be:
|
||||
- str: Name of a method that triggers this listener
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this listener
|
||||
|
||||
Returns
|
||||
@@ -205,18 +200,13 @@ def listen(condition: str | dict | Callable) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -243,7 +233,7 @@ def router(condition: str | dict | Callable) -> Callable:
|
||||
condition : Union[str, dict, Callable]
|
||||
Specifies when the router should execute. Can be:
|
||||
- str: Name of a method that triggers this router
|
||||
- dict: Result from or_() or and_(), including nested conditions
|
||||
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
|
||||
- Callable: A method reference that triggers this router
|
||||
|
||||
Returns
|
||||
@@ -276,18 +266,13 @@ def router(condition: str | dict | Callable) -> Callable:
|
||||
if isinstance(condition, str):
|
||||
func.__trigger_methods__ = [condition]
|
||||
func.__condition_type__ = "OR"
|
||||
elif isinstance(condition, dict) and "type" in condition:
|
||||
if "conditions" in condition:
|
||||
func.__trigger_condition__ = condition
|
||||
func.__trigger_methods__ = _extract_all_methods(condition)
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif "methods" in condition:
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
else:
|
||||
raise ValueError(
|
||||
"Condition dict must contain 'conditions' or 'methods'"
|
||||
)
|
||||
elif (
|
||||
isinstance(condition, dict)
|
||||
and "type" in condition
|
||||
and "methods" in condition
|
||||
):
|
||||
func.__trigger_methods__ = condition["methods"]
|
||||
func.__condition_type__ = condition["type"]
|
||||
elif callable(condition) and hasattr(condition, "__name__"):
|
||||
func.__trigger_methods__ = [condition.__name__]
|
||||
func.__condition_type__ = "OR"
|
||||
@@ -313,15 +298,14 @@ def or_(*conditions: str | dict | Callable) -> dict:
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries (nested conditions)
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "OR", "conditions": list_of_conditions}
|
||||
where each condition can be a string (method name) or a nested dict
|
||||
{"type": "OR", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -333,22 +317,18 @@ def or_(*conditions: str | dict | Callable) -> dict:
|
||||
>>> @listen(or_("success", "timeout"))
|
||||
>>> def handle_completion(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(or_(and_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_conditions: list[str | dict[str, Any]] = []
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
processed_conditions.append(condition)
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
methods.extend(condition["methods"])
|
||||
elif isinstance(condition, str):
|
||||
processed_conditions.append(condition)
|
||||
methods.append(condition)
|
||||
elif callable(condition):
|
||||
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in or_()")
|
||||
return {"type": "OR", "conditions": processed_conditions}
|
||||
return {"type": "OR", "methods": methods}
|
||||
|
||||
|
||||
def and_(*conditions: str | dict | Callable) -> dict:
|
||||
@@ -364,15 +344,14 @@ def and_(*conditions: str | dict | Callable) -> dict:
|
||||
*conditions : Union[str, dict, Callable]
|
||||
Variable number of conditions that can be:
|
||||
- str: Method names
|
||||
- dict: Existing condition dictionaries (nested conditions)
|
||||
- dict: Existing condition dictionaries
|
||||
- Callable: Method references
|
||||
|
||||
Returns
|
||||
-------
|
||||
dict
|
||||
A condition dictionary with format:
|
||||
{"type": "AND", "conditions": list_of_conditions}
|
||||
where each condition can be a string (method name) or a nested dict
|
||||
{"type": "AND", "methods": list_of_method_names}
|
||||
|
||||
Raises
|
||||
------
|
||||
@@ -384,69 +363,18 @@ def and_(*conditions: str | dict | Callable) -> dict:
|
||||
>>> @listen(and_("validated", "processed"))
|
||||
>>> def handle_complete_data(self):
|
||||
... pass
|
||||
|
||||
>>> @listen(and_(or_("step1", "step2"), "step3"))
|
||||
>>> def handle_nested(self):
|
||||
... pass
|
||||
"""
|
||||
processed_conditions: list[str | dict[str, Any]] = []
|
||||
methods = []
|
||||
for condition in conditions:
|
||||
if isinstance(condition, dict):
|
||||
processed_conditions.append(condition)
|
||||
if isinstance(condition, dict) and "methods" in condition:
|
||||
methods.extend(condition["methods"])
|
||||
elif isinstance(condition, str):
|
||||
processed_conditions.append(condition)
|
||||
methods.append(condition)
|
||||
elif callable(condition):
|
||||
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
|
||||
methods.append(getattr(condition, "__name__", repr(condition)))
|
||||
else:
|
||||
raise ValueError("Invalid condition in and_()")
|
||||
return {"type": "AND", "conditions": processed_conditions}
|
||||
|
||||
|
||||
def _normalize_condition(condition: str | dict | list) -> dict:
|
||||
"""Normalize a condition to standard format with 'conditions' key.
|
||||
|
||||
Args:
|
||||
condition: Can be a string (method name), dict (condition), or list
|
||||
|
||||
Returns:
|
||||
Normalized dict with 'type' and 'conditions' keys
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return {"type": "OR", "conditions": [condition]}
|
||||
if isinstance(condition, dict):
|
||||
if "conditions" in condition:
|
||||
return condition
|
||||
if "methods" in condition:
|
||||
return {"type": condition["type"], "conditions": condition["methods"]}
|
||||
return condition
|
||||
if isinstance(condition, list):
|
||||
return {"type": "OR", "conditions": condition}
|
||||
return {"type": "OR", "conditions": [condition]}
|
||||
|
||||
|
||||
def _extract_all_methods(condition: str | dict | list) -> list[str]:
|
||||
"""Extract all method names from a condition (including nested).
|
||||
|
||||
Args:
|
||||
condition: Can be a string, dict, or list
|
||||
|
||||
Returns:
|
||||
List of all method names in the condition tree
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return [condition]
|
||||
if isinstance(condition, dict):
|
||||
normalized = _normalize_condition(condition)
|
||||
methods = []
|
||||
for sub_cond in normalized.get("conditions", []):
|
||||
methods.extend(_extract_all_methods(sub_cond))
|
||||
return methods
|
||||
if isinstance(condition, list):
|
||||
methods = []
|
||||
for item in condition:
|
||||
methods.extend(_extract_all_methods(item))
|
||||
return methods
|
||||
return []
|
||||
return {"type": "AND", "methods": methods}
|
||||
|
||||
|
||||
class FlowMeta(type):
|
||||
@@ -474,10 +402,7 @@ class FlowMeta(type):
|
||||
if hasattr(attr_value, "__trigger_methods__"):
|
||||
methods = attr_value.__trigger_methods__
|
||||
condition_type = getattr(attr_value, "__condition_type__", "OR")
|
||||
if hasattr(attr_value, "__trigger_condition__"):
|
||||
listeners[attr_name] = attr_value.__trigger_condition__
|
||||
else:
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
listeners[attr_name] = (condition_type, methods)
|
||||
|
||||
if (
|
||||
hasattr(attr_value, "__is_router__")
|
||||
@@ -897,7 +822,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
# Clear completed methods and outputs for a fresh start
|
||||
self._completed_methods.clear()
|
||||
self._method_outputs.clear()
|
||||
self._pending_and_listeners.clear()
|
||||
else:
|
||||
# We're restoring from persistence, set the flag
|
||||
self._is_execution_resuming = True
|
||||
@@ -1162,16 +1086,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_data = self._listeners[method_name]
|
||||
should_trigger = False
|
||||
if isinstance(condition_data, tuple):
|
||||
_, trigger_methods = condition_data
|
||||
should_trigger = current_trigger in trigger_methods
|
||||
elif isinstance(condition_data, dict):
|
||||
all_methods = _extract_all_methods(condition_data)
|
||||
should_trigger = current_trigger in all_methods
|
||||
|
||||
if should_trigger:
|
||||
condition_type, trigger_methods = self._listeners[
|
||||
method_name
|
||||
]
|
||||
if current_trigger in trigger_methods:
|
||||
# Only execute if this is a cycle (method was already completed)
|
||||
if method_name in self._completed_methods:
|
||||
# For router-triggered start methods in cycles, temporarily clear resumption flag
|
||||
@@ -1181,51 +1099,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
|
||||
def _evaluate_condition(
|
||||
self, condition: str | dict, trigger_method: str, listener_name: str
|
||||
) -> bool:
|
||||
"""Recursively evaluate a condition (simple or nested).
|
||||
|
||||
Args:
|
||||
condition: Can be a string (method name) or dict (nested condition)
|
||||
trigger_method: The method that just completed
|
||||
listener_name: Name of the listener being evaluated
|
||||
|
||||
Returns:
|
||||
True if the condition is satisfied, False otherwise
|
||||
"""
|
||||
if isinstance(condition, str):
|
||||
return condition == trigger_method
|
||||
|
||||
if isinstance(condition, dict):
|
||||
normalized = _normalize_condition(condition)
|
||||
cond_type = normalized.get("type", "OR")
|
||||
sub_conditions = normalized.get("conditions", [])
|
||||
|
||||
if cond_type == "OR":
|
||||
return any(
|
||||
self._evaluate_condition(sub_cond, trigger_method, listener_name)
|
||||
for sub_cond in sub_conditions
|
||||
)
|
||||
|
||||
if cond_type == "AND":
|
||||
pending_key = f"{listener_name}:{id(condition)}"
|
||||
|
||||
if pending_key not in self._pending_and_listeners:
|
||||
all_methods = set(_extract_all_methods(condition))
|
||||
self._pending_and_listeners[pending_key] = all_methods
|
||||
|
||||
if trigger_method in self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners[pending_key].discard(trigger_method)
|
||||
|
||||
if not self._pending_and_listeners[pending_key]:
|
||||
self._pending_and_listeners.pop(pending_key, None)
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
return False
|
||||
|
||||
def _find_triggered_methods(
|
||||
self, trigger_method: str, router_only: bool
|
||||
) -> list[str]:
|
||||
@@ -1233,7 +1106,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
Finds all methods that should be triggered based on conditions.
|
||||
|
||||
This internal method evaluates both OR and AND conditions to determine
|
||||
which methods should be executed next in the flow. Supports nested conditions.
|
||||
which methods should be executed next in the flow.
|
||||
|
||||
Parameters
|
||||
----------
|
||||
@@ -1250,13 +1123,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
|
||||
Notes
|
||||
-----
|
||||
- Handles both OR and AND conditions, including nested combinations
|
||||
- Handles both OR and AND conditions:
|
||||
* OR: Triggers if any condition is met
|
||||
* AND: Triggers only when all conditions are met
|
||||
- Maintains state for AND conditions using _pending_and_listeners
|
||||
- Separates router and normal listener evaluation
|
||||
"""
|
||||
triggered = []
|
||||
|
||||
for listener_name, condition_data in self._listeners.items():
|
||||
for listener_name, (condition_type, methods) in self._listeners.items():
|
||||
is_router = listener_name in self._routers
|
||||
|
||||
if router_only != is_router:
|
||||
@@ -1265,29 +1139,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
if not router_only and listener_name in self._start_methods:
|
||||
continue
|
||||
|
||||
if isinstance(condition_data, tuple):
|
||||
condition_type, methods = condition_data
|
||||
|
||||
if condition_type == "OR":
|
||||
if trigger_method in methods:
|
||||
triggered.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
if listener_name not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[listener_name] = set(methods)
|
||||
if trigger_method in self._pending_and_listeners[listener_name]:
|
||||
self._pending_and_listeners[listener_name].discard(
|
||||
trigger_method
|
||||
)
|
||||
|
||||
if not self._pending_and_listeners[listener_name]:
|
||||
triggered.append(listener_name)
|
||||
self._pending_and_listeners.pop(listener_name, None)
|
||||
|
||||
elif isinstance(condition_data, dict):
|
||||
if self._evaluate_condition(
|
||||
condition_data, trigger_method, listener_name
|
||||
):
|
||||
if condition_type == "OR":
|
||||
# If the trigger_method matches any in methods, run this
|
||||
if trigger_method in methods:
|
||||
triggered.append(listener_name)
|
||||
elif condition_type == "AND":
|
||||
# Initialize pending methods for this listener if not already done
|
||||
if listener_name not in self._pending_and_listeners:
|
||||
self._pending_and_listeners[listener_name] = set(methods)
|
||||
# Remove the trigger method from pending methods
|
||||
if trigger_method in self._pending_and_listeners[listener_name]:
|
||||
self._pending_and_listeners[listener_name].discard(trigger_method)
|
||||
|
||||
if not self._pending_and_listeners[listener_name]:
|
||||
# All required methods have been executed
|
||||
triggered.append(listener_name)
|
||||
# Reset pending methods for this listener
|
||||
self._pending_and_listeners.pop(listener_name, None)
|
||||
|
||||
return triggered
|
||||
|
||||
@@ -1350,7 +1218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
raise
|
||||
|
||||
def _log_flow_event(
|
||||
self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
|
||||
self, message: str, color: str = "yellow", level: str = "info"
|
||||
) -> None:
|
||||
"""Centralized logging method for flow events.
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import uuid
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from concurrent.futures import Future
|
||||
from copy import copy as shallow_copy
|
||||
from copy import copy
|
||||
from hashlib import md5
|
||||
from pathlib import Path
|
||||
from typing import (
|
||||
@@ -336,12 +336,6 @@ class Task(BaseModel):
|
||||
setattr(self, key, value)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_tools(self):
|
||||
"""Check if the tools are set."""
|
||||
if not self.tools and self.agent and self.agent.tools:
|
||||
self.tools.extend(self.agent.tools)
|
||||
return self
|
||||
|
||||
@model_validator(mode="after")
|
||||
def check_output(self):
|
||||
@@ -672,9 +666,7 @@ Follow these guidelines:
|
||||
copied_data = {k: v for k, v in copied_data.items() if v is not None}
|
||||
|
||||
cloned_context = (
|
||||
self.context
|
||||
if self.context is NOT_SPECIFIED
|
||||
else [task_mapping[context_task.key] for context_task in self.context]
|
||||
[task_mapping[context_task.key] for context_task in self.context]
|
||||
if isinstance(self.context, list)
|
||||
else None
|
||||
)
|
||||
@@ -683,7 +675,7 @@ Follow these guidelines:
|
||||
return next((agent for agent in agents if agent.role == role), None)
|
||||
|
||||
cloned_agent = get_agent_by_role(self.agent.role) if self.agent else None
|
||||
cloned_tools = shallow_copy(self.tools) if self.tools else []
|
||||
cloned_tools = copy(self.tools) if self.tools else []
|
||||
|
||||
return self.__class__(
|
||||
**copied_data,
|
||||
|
||||
@@ -2,18 +2,11 @@ from __future__ import annotations
|
||||
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from typing import TYPE_CHECKING, Any, Final, TypedDict, Union, get_args, get_origin
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
from typing_extensions import Unpack
|
||||
|
||||
if sys.version_info >= (3, 10):
|
||||
import types
|
||||
UnionTypes = (Union, types.UnionType)
|
||||
else:
|
||||
UnionTypes = (Union,)
|
||||
|
||||
from crewai.agents.agent_builder.utilities.base_output_converter import OutputConverter
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -435,21 +428,12 @@ def generate_model_description(model: type[BaseModel]) -> str:
|
||||
origin = get_origin(field_type)
|
||||
args = get_args(field_type)
|
||||
|
||||
if origin in UnionTypes or (origin is None and len(args) > 0):
|
||||
if origin is Union or (origin is None and len(args) > 0):
|
||||
# Handle both Union and the new '|' syntax
|
||||
non_none_args = [arg for arg in args if arg is not type(None)]
|
||||
has_none = type(None) in args
|
||||
|
||||
if has_none:
|
||||
# It's an Optional type
|
||||
if len(non_none_args) == 1:
|
||||
return f"Optional[{describe_field(non_none_args[0])}]"
|
||||
# Union with None and multiple other types
|
||||
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
|
||||
else:
|
||||
if len(non_none_args) == 1:
|
||||
return describe_field(non_none_args[0])
|
||||
return f"Union[{', '.join(describe_field(arg) for arg in args)}]"
|
||||
if len(non_none_args) == 1:
|
||||
return f"Optional[{describe_field(non_none_args[0])}]"
|
||||
return f"Optional[Union[{', '.join(describe_field(arg) for arg in non_none_args)}]]"
|
||||
if origin is list:
|
||||
return f"List[{describe_field(args[0])}]"
|
||||
if origin is dict:
|
||||
|
||||
@@ -1,11 +1,6 @@
|
||||
"""Utility for colored console output."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from _typeshed import SupportsWrite
|
||||
from typing import Final, Literal, NamedTuple
|
||||
|
||||
PrinterColor = Literal[
|
||||
"purple",
|
||||
@@ -59,22 +54,13 @@ class Printer:
|
||||
|
||||
@staticmethod
|
||||
def print(
|
||||
content: str | list[ColoredText],
|
||||
color: PrinterColor | None = None,
|
||||
sep: str | None = " ",
|
||||
end: str | None = "\n",
|
||||
file: SupportsWrite[str] | None = None,
|
||||
flush: Literal[False] = False,
|
||||
content: str | list[ColoredText], color: PrinterColor | None = None
|
||||
) -> None:
|
||||
"""Prints content to the console with optional color formatting.
|
||||
|
||||
Args:
|
||||
content: Either a string or a list of ColoredText objects for multicolor output.
|
||||
color: Optional color for the text when content is a string. Ignored when content is a list.
|
||||
sep: Separator to use between the text and color.
|
||||
end: String appended after the last value.
|
||||
file: A file-like object (stream); defaults to the current sys.stdout.
|
||||
flush: Whether to forcibly flush the stream.
|
||||
"""
|
||||
if isinstance(content, str):
|
||||
content = [ColoredText(content, color)]
|
||||
@@ -82,9 +68,5 @@ class Printer:
|
||||
"".join(
|
||||
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
|
||||
for c in content
|
||||
),
|
||||
sep=sep,
|
||||
end=end,
|
||||
file=file,
|
||||
flush=flush,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import jwt
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai.cli.authentication.utils import validate_jwt_token
|
||||
|
||||
@@ -17,22 +17,19 @@ class TestUtils(unittest.TestCase):
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token=jwt_token,
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
)
|
||||
|
||||
mock_jwt.decode.assert_called_with(
|
||||
jwt_token,
|
||||
"aaaaa.bbbbbb.cccccc",
|
||||
"mock_signing_key",
|
||||
algorithms=["RS256"],
|
||||
audience="app_id_xxxx",
|
||||
issuer="https://mock_issuer",
|
||||
leeway=10.0,
|
||||
options={
|
||||
"verify_signature": True,
|
||||
"verify_exp": True,
|
||||
@@ -46,9 +43,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_expired(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -56,9 +53,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_audience(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -66,9 +63,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_issuer(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -78,9 +75,9 @@ class TestUtils(unittest.TestCase):
|
||||
self, mock_jwt, mock_pyjwkclient
|
||||
):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -88,9 +85,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_jwks_error(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -98,9 +95,9 @@ class TestUtils(unittest.TestCase):
|
||||
|
||||
def test_validate_jwt_token_invalid_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
with self.assertRaises(Exception):
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
|
||||
@@ -200,7 +200,7 @@ def test_async_task_cannot_include_sequential_async_tasks_in_context(
|
||||
# This should raise an error because task2 is async and has task1 in its context without a sync task in between
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Task 'Task 2' is asynchronous and cannot include other sequential asynchronous tasks in its context.",
|
||||
match=r"Task 'Task 2' is asynchronous and cannot include other sequential asynchronous tasks in its context.",
|
||||
):
|
||||
Crew(tasks=[task1, task2, task3, task4, task5], agents=[researcher, writer])
|
||||
|
||||
@@ -238,7 +238,7 @@ def test_context_no_future_tasks(researcher, writer):
|
||||
# This should raise an error because task1 has a context dependency on a future task (task4)
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Task 'Task 1' has a context dependency on a future task 'Task 4', which is not allowed.",
|
||||
match=r"Task 'Task 1' has a context dependency on a future task 'Task 4', which is not allowed.",
|
||||
):
|
||||
Crew(tasks=[task1, task2, task3, task4], agents=[researcher, writer])
|
||||
|
||||
@@ -3339,7 +3339,7 @@ def test_replay_with_invalid_task_id():
|
||||
):
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Task with id bf5b09c9-69bd-4eb8-be12-f9e5bae31c2d not found in the crew's tasks.",
|
||||
match=r"Task with id bf5b09c9-69bd-4eb8-be12-f9e5bae31c2d not found in the crew's tasks.",
|
||||
):
|
||||
crew.replay("bf5b09c9-69bd-4eb8-be12-f9e5bae31c2d")
|
||||
|
||||
@@ -3814,6 +3814,46 @@ def test_fetch_inputs():
|
||||
)
|
||||
|
||||
|
||||
def test_hierarchical_crew_does_not_propagate_agent_tools_to_manager():
|
||||
"""
|
||||
Test that in hierarchical crews, manager agent doesn't inherit task agents' tools.
|
||||
This verifies that the check_tools validator doesn't pollute task.tools at creation time.
|
||||
Fixes issue #3679: https://github.com/crewAIInc/crewAI/issues/3679
|
||||
"""
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def agent_specific_tool() -> str:
|
||||
"""A tool that should only be available to the specific agent."""
|
||||
return "agent specific result"
|
||||
|
||||
agent_with_tools = Agent(
|
||||
role="Specialist",
|
||||
goal="Do specialized work with custom tools",
|
||||
backstory="You are a specialist with specific tools",
|
||||
tools=[agent_specific_tool],
|
||||
allow_delegation=False,
|
||||
)
|
||||
|
||||
# Create a task with an agent that has tools, but don't assign tools to the task
|
||||
task = Task(
|
||||
description="Perform a specialized task",
|
||||
expected_output="Task result",
|
||||
agent=agent_with_tools,
|
||||
)
|
||||
|
||||
Crew(
|
||||
agents=[agent_with_tools],
|
||||
tasks=[task],
|
||||
process=Process.hierarchical,
|
||||
manager_llm="gpt-4o",
|
||||
)
|
||||
|
||||
# Verify that task.tools is empty (not populated with agent's tools)
|
||||
assert task.tools == []
|
||||
assert len(task.tools) == 0
|
||||
|
||||
|
||||
def test_task_tools_preserve_code_execution_tools():
|
||||
"""
|
||||
Test that task tools don't override code execution tools when allow_code_execution=True
|
||||
|
||||
@@ -6,15 +6,15 @@ from datetime import datetime
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.flow_events import (
|
||||
FlowFinishedEvent,
|
||||
FlowPlotEvent,
|
||||
FlowStartedEvent,
|
||||
FlowPlotEvent,
|
||||
MethodExecutionFinishedEvent,
|
||||
MethodExecutionStartedEvent,
|
||||
)
|
||||
from crewai.flow.flow import Flow, and_, listen, or_, router, start
|
||||
|
||||
|
||||
def test_simple_sequential_flow():
|
||||
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
|
||||
assert isinstance(received_events[3], MethodExecutionStartedEvent)
|
||||
assert received_events[3].method_name == "send_welcome_message"
|
||||
assert received_events[3].params == {}
|
||||
assert received_events[3].state.sent is False
|
||||
assert getattr(received_events[3].state, "sent") is False
|
||||
|
||||
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
|
||||
assert received_events[4].method_name == "send_welcome_message"
|
||||
assert received_events[4].state.sent is True
|
||||
assert getattr(received_events[4].state, "sent") is True
|
||||
assert received_events[4].result == "Welcome, Anakin!"
|
||||
|
||||
assert isinstance(received_events[5], FlowFinishedEvent)
|
||||
@@ -894,75 +894,3 @@ def test_flow_name():
|
||||
|
||||
flow = MyFlow()
|
||||
assert flow.name == "MyFlow"
|
||||
|
||||
|
||||
def test_nested_and_or_conditions():
|
||||
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
|
||||
|
||||
Reproduces bug from issue #3719 where nested conditions are flattened,
|
||||
causing premature execution.
|
||||
"""
|
||||
execution_order = []
|
||||
|
||||
class NestedConditionFlow(Flow):
|
||||
@start()
|
||||
def method_1(self):
|
||||
execution_order.append("method_1")
|
||||
|
||||
@listen(method_1)
|
||||
def method_2(self):
|
||||
execution_order.append("method_2")
|
||||
|
||||
@router(method_2)
|
||||
def method_3(self):
|
||||
execution_order.append("method_3")
|
||||
# Choose b_condition path
|
||||
return "b_condition"
|
||||
|
||||
@listen("b_condition")
|
||||
def method_5(self):
|
||||
execution_order.append("method_5")
|
||||
|
||||
@listen(method_5)
|
||||
async def method_4(self):
|
||||
execution_order.append("method_4")
|
||||
|
||||
@listen(or_("a_condition", "b_condition"))
|
||||
async def method_6(self):
|
||||
execution_order.append("method_6")
|
||||
|
||||
@listen(
|
||||
or_(
|
||||
and_("a_condition", method_6),
|
||||
and_(method_6, method_4),
|
||||
)
|
||||
)
|
||||
def method_7(self):
|
||||
execution_order.append("method_7")
|
||||
|
||||
@listen(method_7)
|
||||
async def method_8(self):
|
||||
execution_order.append("method_8")
|
||||
|
||||
flow = NestedConditionFlow()
|
||||
flow.kickoff()
|
||||
|
||||
# Verify execution happened
|
||||
assert "method_1" in execution_order
|
||||
assert "method_2" in execution_order
|
||||
assert "method_3" in execution_order
|
||||
assert "method_5" in execution_order
|
||||
assert "method_4" in execution_order
|
||||
assert "method_6" in execution_order
|
||||
assert "method_7" in execution_order
|
||||
assert "method_8" in execution_order
|
||||
|
||||
# Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
|
||||
# Since b_condition was returned, method_6 triggers on b_condition
|
||||
# method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
|
||||
# The second condition (method_6 AND method_4) should be the one that triggers
|
||||
assert execution_order.index("method_7") > execution_order.index("method_6")
|
||||
assert execution_order.index("method_7") > execution_order.index("method_4")
|
||||
|
||||
# method_8 should execute after method_7
|
||||
assert execution_order.index("method_8") > execution_order.index("method_7")
|
||||
|
||||
@@ -20,11 +20,13 @@ from crewai.utilities.string_utils import interpolate_only
|
||||
|
||||
|
||||
def test_task_tool_reflect_agent_tools():
|
||||
"""Test that agent tools are available during task execution via crew fallback logic."""
|
||||
from crewai.tools import tool
|
||||
|
||||
@tool
|
||||
def fake_tool() -> None:
|
||||
def fake_tool() -> str:
|
||||
"Fake tool"
|
||||
return "result"
|
||||
|
||||
researcher = Agent(
|
||||
role="Researcher",
|
||||
@@ -40,7 +42,9 @@ def test_task_tool_reflect_agent_tools():
|
||||
agent=researcher,
|
||||
)
|
||||
|
||||
assert task.tools == [fake_tool]
|
||||
assert task.tools == []
|
||||
|
||||
assert researcher.tools == [fake_tool]
|
||||
|
||||
|
||||
def test_task_tool_takes_precedence_over_agent_tools():
|
||||
@@ -1635,48 +1639,3 @@ def test_task_interpolation_with_hyphens():
|
||||
assert "say hello world" in task.prompt()
|
||||
|
||||
assert result.raw == "Hello, World!"
|
||||
|
||||
|
||||
def test_task_copy_with_none_context():
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
context=None
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is None
|
||||
assert new_task.context is None
|
||||
|
||||
|
||||
def test_task_copy_with_not_specified_context():
|
||||
from crewai.utilities.constants import NOT_SPECIFIED
|
||||
original_task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
)
|
||||
|
||||
new_task = original_task.copy(agents=[], task_mapping={})
|
||||
assert original_task.context is NOT_SPECIFIED
|
||||
assert new_task.context is NOT_SPECIFIED
|
||||
|
||||
|
||||
def test_task_copy_with_list_context():
|
||||
"""Test that copying a task with list context works correctly."""
|
||||
task1 = Task(
|
||||
description="Task 1",
|
||||
expected_output="Output 1"
|
||||
)
|
||||
task2 = Task(
|
||||
description="Task 2",
|
||||
expected_output="Output 2",
|
||||
context=[task1]
|
||||
)
|
||||
|
||||
task_mapping = {task1.key: task1}
|
||||
|
||||
copied_task2 = task2.copy(agents=[], task_mapping=task_mapping)
|
||||
|
||||
assert isinstance(copied_task2.context, list)
|
||||
assert len(copied_task2.context) == 1
|
||||
assert copied_task2.context[0] is task1
|
||||
|
||||
@@ -1,199 +0,0 @@
|
||||
"""Test Union type support in Pydantic outputs."""
|
||||
import json
|
||||
from typing import Union
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.utilities.converter import (
|
||||
convert_to_model,
|
||||
generate_model_description,
|
||||
)
|
||||
from crewai.utilities.pydantic_schema_parser import PydanticSchemaParser
|
||||
|
||||
|
||||
class SuccessData(BaseModel):
|
||||
"""Model for successful response."""
|
||||
status: str
|
||||
result: str
|
||||
value: int
|
||||
|
||||
|
||||
class ErrorMessage(BaseModel):
|
||||
"""Model for error response."""
|
||||
status: str
|
||||
error: str
|
||||
code: int
|
||||
|
||||
|
||||
class ResponseWithUnion(BaseModel):
|
||||
"""Model with Union type field."""
|
||||
response: Union[SuccessData, ErrorMessage]
|
||||
|
||||
|
||||
class DirectUnionModel(BaseModel):
|
||||
"""Model with direct Union type."""
|
||||
data: Union[str, int, dict]
|
||||
|
||||
|
||||
class MultiUnionModel(BaseModel):
|
||||
"""Model with multiple Union types."""
|
||||
field1: Union[str, int]
|
||||
field2: Union[SuccessData, ErrorMessage, None]
|
||||
|
||||
|
||||
def test_convert_to_model_with_union_success_data():
|
||||
"""Test converting JSON to a model with Union type (SuccessData variant)."""
|
||||
result = json.dumps({
|
||||
"response": {
|
||||
"status": "success",
|
||||
"result": "Operation completed",
|
||||
"value": 42
|
||||
}
|
||||
})
|
||||
|
||||
output = convert_to_model(result, ResponseWithUnion, None, None)
|
||||
assert isinstance(output, ResponseWithUnion)
|
||||
assert isinstance(output.response, SuccessData)
|
||||
assert output.response.status == "success"
|
||||
assert output.response.result == "Operation completed"
|
||||
assert output.response.value == 42
|
||||
|
||||
|
||||
def test_convert_to_model_with_union_error_message():
|
||||
"""Test converting JSON to a model with Union type (ErrorMessage variant)."""
|
||||
result = json.dumps({
|
||||
"response": {
|
||||
"status": "error",
|
||||
"error": "Something went wrong",
|
||||
"code": 500
|
||||
}
|
||||
})
|
||||
|
||||
output = convert_to_model(result, ResponseWithUnion, None, None)
|
||||
assert isinstance(output, ResponseWithUnion)
|
||||
assert isinstance(output.response, ErrorMessage)
|
||||
assert output.response.status == "error"
|
||||
assert output.response.error == "Something went wrong"
|
||||
assert output.response.code == 500
|
||||
|
||||
|
||||
def test_convert_to_model_with_direct_union_string():
|
||||
"""Test converting JSON to a model with direct Union type (string variant)."""
|
||||
result = json.dumps({"data": "hello world"})
|
||||
|
||||
output = convert_to_model(result, DirectUnionModel, None, None)
|
||||
assert isinstance(output, DirectUnionModel)
|
||||
assert isinstance(output.data, str)
|
||||
assert output.data == "hello world"
|
||||
|
||||
|
||||
def test_convert_to_model_with_direct_union_int():
|
||||
"""Test converting JSON to a model with direct Union type (int variant)."""
|
||||
result = json.dumps({"data": 42})
|
||||
|
||||
output = convert_to_model(result, DirectUnionModel, None, None)
|
||||
assert isinstance(output, DirectUnionModel)
|
||||
assert isinstance(output.data, int)
|
||||
assert output.data == 42
|
||||
|
||||
|
||||
def test_convert_to_model_with_direct_union_dict():
|
||||
"""Test converting JSON to a model with direct Union type (dict variant)."""
|
||||
result = json.dumps({"data": {"key": "value", "number": 123}})
|
||||
|
||||
output = convert_to_model(result, DirectUnionModel, None, None)
|
||||
assert isinstance(output, DirectUnionModel)
|
||||
assert isinstance(output.data, dict)
|
||||
assert output.data == {"key": "value", "number": 123}
|
||||
|
||||
|
||||
def test_convert_to_model_with_multiple_unions():
|
||||
"""Test converting JSON to a model with multiple Union type fields."""
|
||||
result = json.dumps({
|
||||
"field1": "text",
|
||||
"field2": {
|
||||
"status": "success",
|
||||
"result": "Done",
|
||||
"value": 100
|
||||
}
|
||||
})
|
||||
|
||||
output = convert_to_model(result, MultiUnionModel, None, None)
|
||||
assert isinstance(output, MultiUnionModel)
|
||||
assert isinstance(output.field1, str)
|
||||
assert output.field1 == "text"
|
||||
assert isinstance(output.field2, SuccessData)
|
||||
assert output.field2.status == "success"
|
||||
|
||||
|
||||
def test_convert_to_model_with_optional_union_none():
|
||||
"""Test converting JSON to a model with optional Union type (None variant)."""
|
||||
result = json.dumps({
|
||||
"field1": 42,
|
||||
"field2": None
|
||||
})
|
||||
|
||||
output = convert_to_model(result, MultiUnionModel, None, None)
|
||||
assert isinstance(output, MultiUnionModel)
|
||||
assert isinstance(output.field1, int)
|
||||
assert output.field1 == 42
|
||||
assert output.field2 is None
|
||||
|
||||
|
||||
def test_generate_model_description_with_union():
|
||||
"""Test that generate_model_description handles Union types correctly."""
|
||||
description = generate_model_description(ResponseWithUnion)
|
||||
|
||||
assert "Union" in description
|
||||
assert "Optional" not in description
|
||||
assert "status" in description
|
||||
print(f"Generated description:\n{description}")
|
||||
|
||||
|
||||
def test_generate_model_description_with_direct_union():
|
||||
"""Test that generate_model_description handles direct Union types correctly."""
|
||||
description = generate_model_description(DirectUnionModel)
|
||||
|
||||
assert "Union" in description
|
||||
assert "Optional" not in description
|
||||
assert "str" in description and "int" in description and "dict" in description
|
||||
print(f"Generated description:\n{description}")
|
||||
|
||||
|
||||
def test_pydantic_schema_parser_with_union():
|
||||
"""Test that PydanticSchemaParser handles Union types correctly."""
|
||||
parser = PydanticSchemaParser(model=ResponseWithUnion)
|
||||
schema = parser.get_schema()
|
||||
|
||||
assert "Union" in schema or "SuccessData" in schema or "ErrorMessage" in schema
|
||||
print(f"Generated schema:\n{schema}")
|
||||
|
||||
|
||||
def test_pydantic_schema_parser_with_direct_union():
|
||||
"""Test that PydanticSchemaParser handles direct Union types correctly."""
|
||||
parser = PydanticSchemaParser(model=DirectUnionModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
assert "Union" in schema or ("str" in schema and "int" in schema and "dict" in schema)
|
||||
print(f"Generated schema:\n{schema}")
|
||||
|
||||
|
||||
def test_pydantic_schema_parser_with_optional_union():
|
||||
"""Test that PydanticSchemaParser handles Optional Union types correctly."""
|
||||
parser = PydanticSchemaParser(model=MultiUnionModel)
|
||||
schema = parser.get_schema()
|
||||
|
||||
assert "Union" in schema or "Optional" in schema
|
||||
print(f"Generated schema:\n{schema}")
|
||||
|
||||
|
||||
def test_generate_model_description_with_optional_union():
|
||||
"""Test that generate_model_description correctly wraps Optional Union types."""
|
||||
description = generate_model_description(MultiUnionModel)
|
||||
|
||||
assert "field1" in description
|
||||
assert "field2" in description
|
||||
assert "Optional" in description
|
||||
print(f"Generated description:\n{description}")
|
||||
@@ -596,5 +596,5 @@ def test_generate_model_description_union_field():
|
||||
field: int | str | None
|
||||
|
||||
description = generate_model_description(UnionModel)
|
||||
expected_description = '{\n "field": Optional[Union[int, str]]\n}'
|
||||
expected_description = '{\n "field": int | str | None\n}'
|
||||
assert description == expected_description
|
||||
|
||||
Reference in New Issue
Block a user