mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-04-20 09:52:36 +00:00

Compare commits: 2 commits (devin/1760...devin/1760)

| Author | SHA1 | Date |
|---|---|---|
| | 1ae3a003b6 | |
| | fc4b0dd923 | |
@@ -283,30 +283,6 @@ class Crew(FlowTrackable, BaseModel):
             "may_not_set_field", "The 'id' field cannot be set by the user.", {}
         )
 
-    @field_validator("embedder", mode="before")
-    @classmethod
-    def normalize_embedder_config(
-        cls, v: dict[str, Any] | None
-    ) -> dict[str, Any] | None:
-        """Normalize embedder config to support both flat and nested formats.
-
-        Args:
-            v: The embedder config to be normalized.
-
-        Returns:
-            The normalized embedder config with nested structure.
-        """
-        if v is None or not isinstance(v, dict):
-            return v
-
-        if "provider" in v and "config" not in v:
-            provider = v["provider"]
-            config_fields = {k: val for k, val in v.items() if k != "provider"}
-            if config_fields:
-                return {"provider": provider, "config": config_fields}
-
-        return v
-
     @field_validator("config", mode="before")
     @classmethod
     def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:
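Note: the removed validator is self-contained and easy to exercise in isolation. A minimal standalone sketch (function body copied from the hunk above; example values taken from the tests further down this diff):

```python
from typing import Any

def normalize_embedder_config(v: dict[str, Any] | None) -> dict[str, Any] | None:
    """Lift flat {"provider": ..., key: ...} specs into {"provider": ..., "config": {...}}."""
    if v is None or not isinstance(v, dict):
        return v
    if "provider" in v and "config" not in v:
        config_fields = {k: val for k, val in v.items() if k != "provider"}
        if config_fields:
            return {"provider": v["provider"], "config": config_fields}
    return v

# Flat input (the format this validator used to accept) becomes nested output.
assert normalize_embedder_config(
    {"provider": "google-generativeai", "api_key": "test-gemini-key"}
) == {"provider": "google-generativeai", "config": {"api_key": "test-gemini-key"}}
```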
@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
 from crewai.flow.persistence.base import FlowPersistence
 from crewai.flow.types import FlowExecutionData
 from crewai.flow.utils import get_possible_return_constants
-from crewai.utilities.printer import Printer, PrinterColor
+from crewai.utilities.printer import Printer
 
 logger = logging.getLogger(__name__)
 
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
     condition : Optional[Union[str, dict, Callable]], optional
         Defines when the start method should execute. Can be:
         - str: Name of a method that triggers this start
-        - dict: Result from or_() or and_(), including nested conditions
+        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
        - Callable: A method reference that triggers this start
         Default is None, meaning unconditional start.
 
@@ -140,18 +140,13 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
        if isinstance(condition, str):
            func.__trigger_methods__ = [condition]
            func.__condition_type__ = "OR"
-        elif isinstance(condition, dict) and "type" in condition:
-            if "conditions" in condition:
-                func.__trigger_condition__ = condition
-                func.__trigger_methods__ = _extract_all_methods(condition)
-                func.__condition_type__ = condition["type"]
-            elif "methods" in condition:
-                func.__trigger_methods__ = condition["methods"]
-                func.__condition_type__ = condition["type"]
-            else:
-                raise ValueError(
-                    "Condition dict must contain 'conditions' or 'methods'"
-                )
+        elif (
+            isinstance(condition, dict)
+            and "type" in condition
+            and "methods" in condition
+        ):
+            func.__trigger_methods__ = condition["methods"]
+            func.__condition_type__ = condition["type"]
        elif callable(condition) and hasattr(condition, "__name__"):
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
@@ -177,7 +172,7 @@ def listen(condition: str | dict | Callable) -> Callable:
     condition : Union[str, dict, Callable]
         Specifies when the listener should execute. Can be:
         - str: Name of a method that triggers this listener
-        - dict: Result from or_() or and_(), including nested conditions
+        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
        - Callable: A method reference that triggers this listener
 
     Returns
@@ -205,18 +200,13 @@ def listen(condition: str | dict | Callable) -> Callable:
        if isinstance(condition, str):
            func.__trigger_methods__ = [condition]
            func.__condition_type__ = "OR"
-        elif isinstance(condition, dict) and "type" in condition:
-            if "conditions" in condition:
-                func.__trigger_condition__ = condition
-                func.__trigger_methods__ = _extract_all_methods(condition)
-                func.__condition_type__ = condition["type"]
-            elif "methods" in condition:
-                func.__trigger_methods__ = condition["methods"]
-                func.__condition_type__ = condition["type"]
-            else:
-                raise ValueError(
-                    "Condition dict must contain 'conditions' or 'methods'"
-                )
+        elif (
+            isinstance(condition, dict)
+            and "type" in condition
+            and "methods" in condition
+        ):
+            func.__trigger_methods__ = condition["methods"]
+            func.__condition_type__ = condition["type"]
        elif callable(condition) and hasattr(condition, "__name__"):
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
@@ -243,7 +233,7 @@ def router(condition: str | dict | Callable) -> Callable:
     condition : Union[str, dict, Callable]
         Specifies when the router should execute. Can be:
         - str: Name of a method that triggers this router
-        - dict: Result from or_() or and_(), including nested conditions
+        - dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
        - Callable: A method reference that triggers this router
 
     Returns
@@ -276,18 +266,13 @@ def router(condition: str | dict | Callable) -> Callable:
        if isinstance(condition, str):
            func.__trigger_methods__ = [condition]
            func.__condition_type__ = "OR"
-        elif isinstance(condition, dict) and "type" in condition:
-            if "conditions" in condition:
-                func.__trigger_condition__ = condition
-                func.__trigger_methods__ = _extract_all_methods(condition)
-                func.__condition_type__ = condition["type"]
-            elif "methods" in condition:
-                func.__trigger_methods__ = condition["methods"]
-                func.__condition_type__ = condition["type"]
-            else:
-                raise ValueError(
-                    "Condition dict must contain 'conditions' or 'methods'"
-                )
+        elif (
+            isinstance(condition, dict)
+            and "type" in condition
+            and "methods" in condition
+        ):
+            func.__trigger_methods__ = condition["methods"]
+            func.__condition_type__ = condition["type"]
        elif callable(condition) and hasattr(condition, "__name__"):
            func.__trigger_methods__ = [condition.__name__]
            func.__condition_type__ = "OR"
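For orientation, a small hypothetical flow that stays within the retained decorator surface (flat string/callable conditions only; class and method names are invented for illustration):

```python
from crewai.flow.flow import Flow, listen, or_, router, start

class SignupFlow(Flow):  # hypothetical example flow
    @start()
    def collect_email(self):
        return "user@example.com"

    @router(collect_email)  # callable condition: triggers when collect_email runs
    def validate(self):
        # Routers return a label; listeners can be keyed on that label.
        return "valid"

    @listen("valid")  # string condition: a router label
    def create_account(self):
        return "created"

    @listen(or_("valid", "create_account"))
    def audit_log(self):
        # Flat OR condition: fires once either trigger has run.
        return "logged"
```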
@@ -313,15 +298,14 @@ def or_(*conditions: str | dict | Callable) -> dict:
     *conditions : Union[str, dict, Callable]
         Variable number of conditions that can be:
         - str: Method names
-        - dict: Existing condition dictionaries (nested conditions)
+        - dict: Existing condition dictionaries
        - Callable: Method references
 
     Returns
     -------
     dict
         A condition dictionary with format:
-        {"type": "OR", "conditions": list_of_conditions}
-        where each condition can be a string (method name) or a nested dict
+        {"type": "OR", "methods": list_of_method_names}
 
     Raises
     ------
@@ -333,22 +317,18 @@ def or_(*conditions: str | dict | Callable) -> dict:
     >>> @listen(or_("success", "timeout"))
     >>> def handle_completion(self):
     ... pass
-
-    >>> @listen(or_(and_("step1", "step2"), "step3"))
-    >>> def handle_nested(self):
-    ... pass
     """
-    processed_conditions: list[str | dict[str, Any]] = []
+    methods = []
     for condition in conditions:
-        if isinstance(condition, dict):
-            processed_conditions.append(condition)
+        if isinstance(condition, dict) and "methods" in condition:
+            methods.extend(condition["methods"])
        elif isinstance(condition, str):
-            processed_conditions.append(condition)
+            methods.append(condition)
        elif callable(condition):
-            processed_conditions.append(getattr(condition, "__name__", repr(condition)))
+            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
            raise ValueError("Invalid condition in or_()")
-    return {"type": "OR", "conditions": processed_conditions}
+    return {"type": "OR", "methods": methods}
 
 
 def and_(*conditions: str | dict | Callable) -> dict:
@@ -364,15 +344,14 @@ def and_(*conditions: str | dict | Callable) -> dict:
     *conditions : Union[str, dict, Callable]
         Variable number of conditions that can be:
         - str: Method names
-        - dict: Existing condition dictionaries (nested conditions)
+        - dict: Existing condition dictionaries
        - Callable: Method references
 
     Returns
     -------
     dict
         A condition dictionary with format:
-        {"type": "AND", "conditions": list_of_conditions}
-        where each condition can be a string (method name) or a nested dict
+        {"type": "AND", "methods": list_of_method_names}
 
     Raises
     ------
@@ -384,69 +363,18 @@ def and_(*conditions: str | dict | Callable) -> dict:
     >>> @listen(and_("validated", "processed"))
     >>> def handle_complete_data(self):
     ... pass
-
-    >>> @listen(and_(or_("step1", "step2"), "step3"))
-    >>> def handle_nested(self):
-    ... pass
     """
-    processed_conditions: list[str | dict[str, Any]] = []
+    methods = []
     for condition in conditions:
-        if isinstance(condition, dict):
-            processed_conditions.append(condition)
+        if isinstance(condition, dict) and "methods" in condition:
+            methods.extend(condition["methods"])
        elif isinstance(condition, str):
-            processed_conditions.append(condition)
+            methods.append(condition)
        elif callable(condition):
-            processed_conditions.append(getattr(condition, "__name__", repr(condition)))
+            methods.append(getattr(condition, "__name__", repr(condition)))
        else:
            raise ValueError("Invalid condition in and_()")
-    return {"type": "AND", "conditions": processed_conditions}
-
-
-def _normalize_condition(condition: str | dict | list) -> dict:
-    """Normalize a condition to standard format with 'conditions' key.
-
-    Args:
-        condition: Can be a string (method name), dict (condition), or list
-
-    Returns:
-        Normalized dict with 'type' and 'conditions' keys
-    """
-    if isinstance(condition, str):
-        return {"type": "OR", "conditions": [condition]}
-    if isinstance(condition, dict):
-        if "conditions" in condition:
-            return condition
-        if "methods" in condition:
-            return {"type": condition["type"], "conditions": condition["methods"]}
-        return condition
-    if isinstance(condition, list):
-        return {"type": "OR", "conditions": condition}
-    return {"type": "OR", "conditions": [condition]}
-
-
-def _extract_all_methods(condition: str | dict | list) -> list[str]:
-    """Extract all method names from a condition (including nested).
-
-    Args:
-        condition: Can be a string, dict, or list
-
-    Returns:
-        List of all method names in the condition tree
-    """
-    if isinstance(condition, str):
-        return [condition]
-    if isinstance(condition, dict):
-        normalized = _normalize_condition(condition)
-        methods = []
-        for sub_cond in normalized.get("conditions", []):
-            methods.extend(_extract_all_methods(sub_cond))
-        return methods
-    if isinstance(condition, list):
-        methods = []
-        for item in condition:
-            methods.extend(_extract_all_methods(item))
-        return methods
-    return []
+    return {"type": "AND", "methods": methods}
 
 
 class FlowMeta(type):
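For reference, the two condition encodings in play, written out as plain dicts (shapes taken from the docstrings above; the nested form is the one being removed):

```python
# Flat form (retained): one level of method names.
flat = {"type": "OR", "methods": ["success", "timeout"]}

# Nested form (removed): entries may themselves be condition dicts.
nested = {
    "type": "OR",
    "conditions": [
        {"type": "AND", "conditions": ["step1", "step2"]},
        "step3",
    ],
}
```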
@@ -474,10 +402,7 @@ class FlowMeta(type):
            if hasattr(attr_value, "__trigger_methods__"):
                methods = attr_value.__trigger_methods__
                condition_type = getattr(attr_value, "__condition_type__", "OR")
-                if hasattr(attr_value, "__trigger_condition__"):
-                    listeners[attr_name] = attr_value.__trigger_condition__
-                else:
-                    listeners[attr_name] = (condition_type, methods)
+                listeners[attr_name] = (condition_type, methods)
 
                if (
                    hasattr(attr_value, "__is_router__")
@@ -897,7 +822,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
            # Clear completed methods and outputs for a fresh start
            self._completed_methods.clear()
            self._method_outputs.clear()
-            self._pending_and_listeners.clear()
        else:
            # We're restoring from persistence, set the flag
            self._is_execution_resuming = True
@@ -1162,16 +1086,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
        for method_name in self._start_methods:
            # Check if this start method is triggered by the current trigger
            if method_name in self._listeners:
-                condition_data = self._listeners[method_name]
-                should_trigger = False
-                if isinstance(condition_data, tuple):
-                    _, trigger_methods = condition_data
-                    should_trigger = current_trigger in trigger_methods
-                elif isinstance(condition_data, dict):
-                    all_methods = _extract_all_methods(condition_data)
-                    should_trigger = current_trigger in all_methods
-
-                if should_trigger:
+                condition_type, trigger_methods = self._listeners[
+                    method_name
+                ]
+                if current_trigger in trigger_methods:
                    # Only execute if this is a cycle (method was already completed)
                    if method_name in self._completed_methods:
                        # For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1181,51 +1099,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
                        await self._execute_start_method(method_name)
                        self._is_execution_resuming = was_resuming
 
-    def _evaluate_condition(
-        self, condition: str | dict, trigger_method: str, listener_name: str
-    ) -> bool:
-        """Recursively evaluate a condition (simple or nested).
-
-        Args:
-            condition: Can be a string (method name) or dict (nested condition)
-            trigger_method: The method that just completed
-            listener_name: Name of the listener being evaluated
-
-        Returns:
-            True if the condition is satisfied, False otherwise
-        """
-        if isinstance(condition, str):
-            return condition == trigger_method
-
-        if isinstance(condition, dict):
-            normalized = _normalize_condition(condition)
-            cond_type = normalized.get("type", "OR")
-            sub_conditions = normalized.get("conditions", [])
-
-            if cond_type == "OR":
-                return any(
-                    self._evaluate_condition(sub_cond, trigger_method, listener_name)
-                    for sub_cond in sub_conditions
-                )
-
-            if cond_type == "AND":
-                pending_key = f"{listener_name}:{id(condition)}"
-
-                if pending_key not in self._pending_and_listeners:
-                    all_methods = set(_extract_all_methods(condition))
-                    self._pending_and_listeners[pending_key] = all_methods
-
-                if trigger_method in self._pending_and_listeners[pending_key]:
-                    self._pending_and_listeners[pending_key].discard(trigger_method)
-
-                if not self._pending_and_listeners[pending_key]:
-                    self._pending_and_listeners.pop(pending_key, None)
-                    return True
-
-                return False
-
-            return False
-
-        return False
-
     def _find_triggered_methods(
        self, trigger_method: str, router_only: bool
     ) -> list[str]:
@@ -1233,7 +1106,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
        Finds all methods that should be triggered based on conditions.
 
        This internal method evaluates both OR and AND conditions to determine
-        which methods should be executed next in the flow. Supports nested conditions.
+        which methods should be executed next in the flow.
 
        Parameters
        ----------
@@ -1250,13 +1123,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
 
        Notes
        -----
-        - Handles both OR and AND conditions, including nested combinations
+        - Handles both OR and AND conditions:
+            * OR: Triggers if any condition is met
+            * AND: Triggers only when all conditions are met
        - Maintains state for AND conditions using _pending_and_listeners
        - Separates router and normal listener evaluation
        """
        triggered = []
 
-        for listener_name, condition_data in self._listeners.items():
+        for listener_name, (condition_type, methods) in self._listeners.items():
            is_router = listener_name in self._routers
 
            if router_only != is_router:
@@ -1265,29 +1139,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
            if not router_only and listener_name in self._start_methods:
                continue
 
-            if isinstance(condition_data, tuple):
-                condition_type, methods = condition_data
-
-                if condition_type == "OR":
-                    if trigger_method in methods:
-                        triggered.append(listener_name)
-                elif condition_type == "AND":
-                    if listener_name not in self._pending_and_listeners:
-                        self._pending_and_listeners[listener_name] = set(methods)
-                    if trigger_method in self._pending_and_listeners[listener_name]:
-                        self._pending_and_listeners[listener_name].discard(
-                            trigger_method
-                        )
-
-                    if not self._pending_and_listeners[listener_name]:
-                        triggered.append(listener_name)
-                        self._pending_and_listeners.pop(listener_name, None)
-
-            elif isinstance(condition_data, dict):
-                if self._evaluate_condition(
-                    condition_data, trigger_method, listener_name
-                ):
+            if condition_type == "OR":
+                # If the trigger_method matches any in methods, run this
+                if trigger_method in methods:
                    triggered.append(listener_name)
+            elif condition_type == "AND":
+                # Initialize pending methods for this listener if not already done
+                if listener_name not in self._pending_and_listeners:
+                    self._pending_and_listeners[listener_name] = set(methods)
+                # Remove the trigger method from pending methods
+                if trigger_method in self._pending_and_listeners[listener_name]:
+                    self._pending_and_listeners[listener_name].discard(trigger_method)
+
+                if not self._pending_and_listeners[listener_name]:
+                    # All required methods have been executed
+                    triggered.append(listener_name)
+                    # Reset pending methods for this listener
+                    self._pending_and_listeners.pop(listener_name, None)
 
        return triggered
 
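The retained AND bookkeeping can be sketched standalone (hypothetical names; the same discard-until-empty idea as _pending_and_listeners above):

```python
pending: dict[str, set[str]] = {}

def on_trigger(listener: str, required: set[str], trigger: str) -> bool:
    """Return True once every method in `required` has fired at least once."""
    remaining = pending.setdefault(listener, set(required))
    remaining.discard(trigger)
    if not remaining:
        pending.pop(listener, None)  # reset so a later cycle re-arms the listener
        return True
    return False

assert on_trigger("report", {"step1", "step2"}, "step1") is False
assert on_trigger("report", {"step1", "step2"}, "step2") is True
```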
@@ -1350,7 +1218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
            raise
 
    def _log_flow_event(
-        self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
+        self, message: str, color: str = "yellow", level: str = "info"
    ) -> None:
        """Centralized logging method for flow events.
 
@@ -299,6 +299,7 @@ class LLM(BaseLLM):
        callbacks: list[Any] | None = None,
        reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
        stream: bool = False,
+        supports_function_calling: bool | None = None,
        **kwargs,
    ):
        self.model = model
@@ -325,6 +326,7 @@ class LLM(BaseLLM):
        self.additional_params = kwargs
        self.is_anthropic = self._is_anthropic_model(model)
        self.stream = stream
+        self._supports_function_calling_override = supports_function_calling
 
        litellm.drop_params = True
 
@@ -1197,6 +1199,9 @@ class LLM(BaseLLM):
        )
 
    def supports_function_calling(self) -> bool:
+        if self._supports_function_calling_override is not None:
+            return self._supports_function_calling_override
+
        try:
            provider = self._get_custom_llm_provider()
            return litellm.utils.supports_function_calling(
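Usage of the new keyword follows directly from this hunk; these calls mirror the tests at the end of this diff:

```python
from crewai import LLM

# Force function calling on for a model litellm cannot classify ...
llm = LLM(model="custom-model/my-model", supports_function_calling=True)
assert llm.supports_function_calling() is True

# ... or force it off; the override wins over litellm's capability lookup.
llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
assert llm.supports_function_calling() is False
```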
@@ -9,6 +9,7 @@ from typing import Any, Final
 
 DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
 DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
+DEFAULT_SUPPORTS_FUNCTION_CALLING: Final[bool] = True
 
 
 class BaseLLM(ABC):
@@ -82,6 +83,14 @@ class BaseLLM(ABC):
            RuntimeError: If the LLM request fails for other reasons.
        """
 
+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+        return DEFAULT_SUPPORTS_FUNCTION_CALLING
+
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
 
@@ -228,11 +228,8 @@ def build_embedder_from_dict(spec):
    """Build an embedding function instance from a dictionary specification.
 
    Args:
-        spec: A dictionary with 'provider' and optionally 'config' keys.
-            Supports two formats:
-
-            Nested format (recommended):
-            {
+        spec: A dictionary with 'provider' and 'config' keys.
+            Example: {
                "provider": "openai",
                "config": {
                    "api_key": "sk-...",
@@ -240,13 +237,6 @@ def build_embedder_from_dict(spec):
                }
            }
 
-            Flat format (for backward compatibility):
-            {
-                "provider": "openai",
-                "api_key": "sk-...",
-                "model_name": "text-embedding-3-small"
-            }
-
    Returns:
        An instance of the appropriate embedding function.
 
@@ -276,10 +266,7 @@ def build_embedder_from_dict(spec):
    except (ImportError, AttributeError, ValueError) as e:
        raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
 
-    if "config" in spec:
-        provider_config = spec["config"]
-    else:
-        provider_config = {k: v for k, v in spec.items() if k != "provider"}
+    provider_config = spec.get("config", {})
 
    if provider_name == "custom" and "embedding_callable" not in provider_config:
        raise ValueError("Custom provider requires 'embedding_callable' in config")
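With the flat fallback gone, only the nested spec reaches the factory. A sketch of a call, assuming build_embedder is exported from the factory module the tests patch (key values are placeholders from those tests):

```python
from crewai.rag.embeddings.factory import build_embedder

embedder = build_embedder({
    "provider": "google-generativeai",
    "config": {
        "api_key": "test-gemini-key",  # placeholder credential
        "model_name": "models/text-embedding-004",
    },
})
```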
@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
    task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
 
 
-class GenerativeAiProviderSpec(TypedDict, total=False):
+class GenerativeAiProviderSpec(TypedDict):
    """Google Generative AI provider specification."""
 
-    provider: Required[Literal["google-generativeai"]]
+    provider: Literal["google-generativeai"]
    config: GenerativeAiProviderConfig
 
 
@@ -1,11 +1,6 @@
 """Utility for colored console output."""
 
-from __future__ import annotations
-
-from typing import TYPE_CHECKING, Final, Literal, NamedTuple
-
-if TYPE_CHECKING:
-    from _typeshed import SupportsWrite
+from typing import Final, Literal, NamedTuple
 
 PrinterColor = Literal[
    "purple",
@@ -59,22 +54,13 @@ class Printer:
 
    @staticmethod
    def print(
-        content: str | list[ColoredText],
-        color: PrinterColor | None = None,
-        sep: str | None = " ",
-        end: str | None = "\n",
-        file: SupportsWrite[str] | None = None,
-        flush: Literal[False] = False,
+        content: str | list[ColoredText], color: PrinterColor | None = None
    ) -> None:
        """Prints content to the console with optional color formatting.
 
        Args:
            content: Either a string or a list of ColoredText objects for multicolor output.
            color: Optional color for the text when content is a string. Ignored when content is a list.
-            sep: Separator to use between the text and color.
-            end: String appended after the last value.
-            file: A file-like object (stream); defaults to the current sys.stdout.
-            flush: Whether to forcibly flush the stream.
        """
        if isinstance(content, str):
            content = [ColoredText(content, color)]
@@ -82,9 +68,5 @@ class Printer:
            "".join(
                f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
                for c in content
-            ),
-            sep=sep,
-            end=end,
-            file=file,
-            flush=flush,
+            )
        )
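A short usage sketch of the narrowed signature, assuming ColoredText is exported alongside Printer (colors taken from the literals and defaults visible in this diff):

```python
from crewai.utilities.printer import ColoredText, Printer

printer = Printer()
printer.print("flow started", color="yellow")       # single-color string
printer.print([                                      # multicolor segments
    ColoredText("ok ", "purple"),
    ColoredText("details follow", None),             # None = no color code
])
```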
@@ -242,61 +242,3 @@ class TestEmbeddingFactory:
        mock_build_from_provider.assert_called_once_with(mock_provider)
        assert result == mock_embedding_function
        mock_import.assert_not_called()
-
-    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
-    def test_build_embedder_google_generativeai_nested_config(self, mock_import):
-        """Test building Google Generative AI embedder with nested config format."""
-        mock_provider_class = MagicMock()
-        mock_provider_instance = MagicMock()
-        mock_embedding_function = MagicMock()
-
-        mock_import.return_value = mock_provider_class
-        mock_provider_class.return_value = mock_provider_instance
-        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
-
-        config = {
-            "provider": "google-generativeai",
-            "config": {
-                "api_key": "test-gemini-key",
-                "model_name": "models/text-embedding-004",
-            },
-        }
-
-        build_embedder(config)
-
-        mock_import.assert_called_once_with(
-            "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
-        )
-        mock_provider_class.assert_called_once()
-
-        call_kwargs = mock_provider_class.call_args.kwargs
-        assert call_kwargs["api_key"] == "test-gemini-key"
-        assert call_kwargs["model_name"] == "models/text-embedding-004"
-
-    @patch("crewai.rag.embeddings.factory.import_and_validate_definition")
-    def test_build_embedder_google_generativeai_flat_config(self, mock_import):
-        """Test building Google Generative AI embedder with flat config format (issue #3741)."""
-        mock_provider_class = MagicMock()
-        mock_provider_instance = MagicMock()
-        mock_embedding_function = MagicMock()
-
-        mock_import.return_value = mock_provider_class
-        mock_provider_class.return_value = mock_provider_instance
-        mock_provider_instance.embedding_callable.return_value = mock_embedding_function
-
-        config = {
-            "provider": "google-generativeai",
-            "api_key": "test-gemini-key",
-            "model_name": "models/text-embedding-004",
-        }
-
-        build_embedder(config)
-
-        mock_import.assert_called_once_with(
-            "crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
-        )
-        mock_provider_class.assert_called_once()
-
-        call_kwargs = mock_provider_class.call_args.kwargs
-        assert call_kwargs["api_key"] == "test-gemini-key"
-        assert call_kwargs["model_name"] == "models/text-embedding-004"
@@ -1,107 +0,0 @@
-"""Tests for Google Generative AI embedder configuration (issue #3741)."""
-
-from unittest.mock import MagicMock, patch
-
-import pytest
-
-from crewai import Agent, Crew, Task
-
-
-class TestGoogleGenerativeAIEmbedder:
-    """Test Google Generative AI embedder configuration formats."""
-
-    @patch("crewai.crew.Knowledge")
-    @patch("crewai.crew.ShortTermMemory")
-    @patch("crewai.crew.LongTermMemory")
-    @patch("crewai.crew.EntityMemory")
-    def test_crew_with_google_generativeai_flat_config(
-        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
-    ):
-        """Test that Crew accepts google-generativeai embedder with flat config format (issue #3741)."""
-        agent = Agent(
-            role="Test Agent",
-            goal="Test goal",
-            backstory="Test backstory",
-        )
-
-        task = Task(
-            description="Test task",
-            expected_output="Test output",
-            agent=agent,
-        )
-
-        embedder_config = {
-            "provider": "google-generativeai",
-            "api_key": "test-gemini-key",
-            "model_name": "models/text-embedding-004",
-        }
-
-        crew = Crew(
-            agents=[agent],
-            tasks=[task],
-            embedder=embedder_config,
-        )
-
-        expected_normalized_config = {
-            "provider": "google-generativeai",
-            "config": {
-                "api_key": "test-gemini-key",
-                "model_name": "models/text-embedding-004",
-            },
-        }
-        assert crew.embedder == expected_normalized_config
-
-    @patch("crewai.crew.Knowledge")
-    @patch("crewai.crew.ShortTermMemory")
-    @patch("crewai.crew.LongTermMemory")
-    @patch("crewai.crew.EntityMemory")
-    def test_crew_with_google_generativeai_nested_config(
-        self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
-    ):
-        """Test that Crew accepts google-generativeai embedder with nested config format."""
-        agent = Agent(
-            role="Test Agent",
-            goal="Test goal",
-            backstory="Test backstory",
-        )
-
-        task = Task(
-            description="Test task",
-            expected_output="Test output",
-            agent=agent,
-        )
-
-        embedder_config = {
-            "provider": "google-generativeai",
-            "config": {
-                "api_key": "test-gemini-key",
-                "model_name": "models/text-embedding-004",
-            },
-        }
-
-        crew = Crew(
-            agents=[agent],
-            tasks=[task],
-            embedder=embedder_config,
-        )
-
-        assert crew.embedder == embedder_config
-
-    def test_generativeai_provider_spec_validation(self):
-        """Test that GenerativeAiProviderSpec validates correctly with optional config."""
-        from crewai.rag.embeddings.types import GenerativeAiProviderSpec
-
-        flat_spec: GenerativeAiProviderSpec = {
-            "provider": "google-generativeai",
-        }
-        assert flat_spec["provider"] == "google-generativeai"
-
-        nested_spec: GenerativeAiProviderSpec = {
-            "provider": "google-generativeai",
-            "config": {
-                "api_key": "test-key",
-                "model_name": "models/text-embedding-004",
-            },
-        }
-        assert nested_spec["provider"] == "google-generativeai"
-        assert nested_spec["config"]["api_key"] == "test-key"
@@ -1,4 +1,4 @@
-from typing import Any, Dict, List, Optional, Union
+from typing import Any
 
 import pytest
 
@@ -159,11 +159,11 @@ class JWTAuthLLM(BaseLLM):
 
    def call(
        self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
        """Record the call and return a predefined response."""
        self.calls.append(
            {
@@ -192,7 +192,7 @@ class JWTAuthLLM(BaseLLM):
 
 def test_custom_llm_with_jwt_auth():
    """Test a custom LLM implementation with JWT authentication."""
-    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
+    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")  # noqa: S106
 
    # Test that create_llm returns the JWT-authenticated LLM instance directly
    result_llm = create_llm(jwt_llm)
@@ -238,11 +238,11 @@ class TimeoutHandlingLLM(BaseLLM):
 
    def call(
        self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
        """Simulate API calls with timeout handling and retry logic.
 
        Args:
@@ -282,35 +282,34 @@ class TimeoutHandlingLLM(BaseLLM):
                        )
                    # Otherwise, continue to the next attempt (simulating backoff)
                    continue
-                else:
-                    # Success on first attempt
-                    return "First attempt response"
-            else:
-                # This is a retry attempt (attempt > 0)
-                # Always record retry attempts
-                self.calls.append(
-                    {
-                        "retry_attempt": attempt,
-                        "messages": messages,
-                        "tools": tools,
-                        "callbacks": callbacks,
-                        "available_functions": available_functions,
-                    }
-                )
+                # Success on first attempt
+                return "First attempt response"
+            # This is a retry attempt (attempt > 0)
+            # Always record retry attempts
+            self.calls.append(
+                {
+                    "retry_attempt": attempt,
+                    "messages": messages,
+                    "tools": tools,
+                    "callbacks": callbacks,
+                    "available_functions": available_functions,
+                }
+            )
 
-                # Simulate a failure if fail_count > 0
-                if self.fail_count > 0:
-                    self.fail_count -= 1
-                    # If we've used all retries, raise an error
-                    if attempt == self.max_retries - 1:
-                        raise TimeoutError(
-                            f"LLM request failed after {self.max_retries} attempts"
-                        )
-                    # Otherwise, continue to the next attempt (simulating backoff)
-                    continue
-                else:
-                    # Success on retry
-                    return "Response after retry"
+            # Simulate a failure if fail_count > 0
+            if self.fail_count > 0:
+                self.fail_count -= 1
+                # If we've used all retries, raise an error
+                if attempt == self.max_retries - 1:
+                    raise TimeoutError(
+                        f"LLM request failed after {self.max_retries} attempts"
+                    )
+                # Otherwise, continue to the next attempt (simulating backoff)
+                continue
+            # Success on retry
+            return "Response after retry"
 
+        return "Response after retry"
 
    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.
@@ -358,3 +357,25 @@ def test_timeout_handling_llm():
    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
        llm.call("Test message")
    assert len(llm.calls) == 2  # Initial call + failed retry attempt
+
+
+class MinimalCustomLLM(BaseLLM):
+    """Minimal custom LLM implementation that doesn't override supports_function_calling."""
+
+    def __init__(self):
+        super().__init__(model="minimal-model")
+
+    def call(
+        self,
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+    ) -> str | Any:
+        return "Minimal response"
+
+
+def test_base_llm_supports_function_calling_default():
+    """Test that BaseLLM supports function calling by default."""
+    llm = MinimalCustomLLM()
+    assert llm.supports_function_calling() is True
@@ -6,15 +6,15 @@ from datetime import datetime
 import pytest
 from pydantic import BaseModel
 
-from crewai.flow.flow import Flow, and_, listen, or_, router, start
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.flow_events import (
    FlowFinishedEvent,
+    FlowPlotEvent,
    FlowStartedEvent,
-    FlowPlotEvent,
    MethodExecutionFinishedEvent,
    MethodExecutionStartedEvent,
 )
+from crewai.flow.flow import Flow, and_, listen, or_, router, start
 
 
 def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
    assert isinstance(received_events[3], MethodExecutionStartedEvent)
    assert received_events[3].method_name == "send_welcome_message"
    assert received_events[3].params == {}
-    assert received_events[3].state.sent is False
+    assert getattr(received_events[3].state, "sent") is False
 
    assert isinstance(received_events[4], MethodExecutionFinishedEvent)
    assert received_events[4].method_name == "send_welcome_message"
-    assert received_events[4].state.sent is True
+    assert getattr(received_events[4].state, "sent") is True
    assert received_events[4].result == "Welcome, Anakin!"
 
    assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,75 +894,3 @@ def test_flow_name():
 
    flow = MyFlow()
    assert flow.name == "MyFlow"
-
-
-def test_nested_and_or_conditions():
-    """Test nested conditions like or_(and_(A, B), and_(C, D)).
-
-    Reproduces bug from issue #3719 where nested conditions are flattened,
-    causing premature execution.
-    """
-    execution_order = []
-
-    class NestedConditionFlow(Flow):
-        @start()
-        def method_1(self):
-            execution_order.append("method_1")
-
-        @listen(method_1)
-        def method_2(self):
-            execution_order.append("method_2")
-
-        @router(method_2)
-        def method_3(self):
-            execution_order.append("method_3")
-            # Choose b_condition path
-            return "b_condition"
-
-        @listen("b_condition")
-        def method_5(self):
-            execution_order.append("method_5")
-
-        @listen(method_5)
-        async def method_4(self):
-            execution_order.append("method_4")
-
-        @listen(or_("a_condition", "b_condition"))
-        async def method_6(self):
-            execution_order.append("method_6")
-
-        @listen(
-            or_(
-                and_("a_condition", method_6),
-                and_(method_6, method_4),
-            )
-        )
-        def method_7(self):
-            execution_order.append("method_7")
-
-        @listen(method_7)
-        async def method_8(self):
-            execution_order.append("method_8")
-
-    flow = NestedConditionFlow()
-    flow.kickoff()
-
-    # Verify execution happened
-    assert "method_1" in execution_order
-    assert "method_2" in execution_order
-    assert "method_3" in execution_order
-    assert "method_5" in execution_order
-    assert "method_4" in execution_order
-    assert "method_6" in execution_order
-    assert "method_7" in execution_order
-    assert "method_8" in execution_order
-
-    # Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
-    # Since b_condition was returned, method_6 triggers on b_condition
-    # method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
-    # The second condition (method_6 AND method_4) should be the one that triggers
-    assert execution_order.index("method_7") > execution_order.index("method_6")
-    assert execution_order.index("method_7") > execution_order.index("method_4")
-
-    # method_8 should execute after method_7
-    assert execution_order.index("method_8") > execution_order.index("method_7")
@@ -711,3 +711,18 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
    formatted = ollama_llm._format_messages_for_provider(original_messages)
 
    assert formatted == original_messages
+
+
+def test_supports_function_calling_with_override_true():
+    llm = LLM(model="custom-model/my-model", supports_function_calling=True)
+    assert llm.supports_function_calling() is True
+
+
+def test_supports_function_calling_with_override_false():
+    llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
+    assert llm.supports_function_calling() is False
+
+
+def test_supports_function_calling_without_override():
+    llm = LLM(model="gpt-4o-mini")
+    assert llm.supports_function_calling() is True