Compare commits

..

2 Commits

Author SHA1 Message Date
Devin AI
1ae3a003b6 Fix lint errors in test_custom_llm.py
- Add noqa comment for hardcoded test JWT token
- Add return statement to satisfy ruff RET503 check

Co-Authored-By: João <joao@crewai.com>
2025-10-15 03:01:54 +00:00
Devin AI
fc4b0dd923 Fix function_calling_llm support for custom models
- Add supports_function_calling() method to BaseLLM class with default True
- Add supports_function_calling parameter to LLM class to allow override of litellm check
- Add tests for both BaseLLM default and LLM override functionality
- Fixes #3708: Custom models not in litellm's list can now use function calling

Co-Authored-By: João <joao@crewai.com>
2025-10-15 02:57:03 +00:00
13 changed files with 3623 additions and 3906 deletions

View File

@@ -283,30 +283,6 @@ class Crew(FlowTrackable, BaseModel):
"may_not_set_field", "The 'id' field cannot be set by the user.", {}
)
@field_validator("embedder", mode="before")
@classmethod
def normalize_embedder_config(
cls, v: dict[str, Any] | None
) -> dict[str, Any] | None:
"""Normalize embedder config to support both flat and nested formats.
Args:
v: The embedder config to be normalized.
Returns:
The normalized embedder config with nested structure.
"""
if v is None or not isinstance(v, dict):
return v
if "provider" in v and "config" not in v:
provider = v["provider"]
config_fields = {k: val for k, val in v.items() if k != "provider"}
if config_fields:
return {"provider": provider, "config": config_fields}
return v
@field_validator("config", mode="before")
@classmethod
def check_config_type(cls, v: Json | dict[str, Any]) -> Json | dict[str, Any]:

View File

@@ -31,7 +31,7 @@ from crewai.flow.flow_visualizer import plot_flow
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import FlowExecutionData
from crewai.flow.utils import get_possible_return_constants
from crewai.utilities.printer import Printer, PrinterColor
from crewai.utilities.printer import Printer
logger = logging.getLogger(__name__)
@@ -105,7 +105,7 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
condition : Optional[Union[str, dict, Callable]], optional
Defines when the start method should execute. Can be:
- str: Name of a method that triggers this start
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this start
Default is None, meaning unconditional start.
@@ -140,18 +140,13 @@ def start(condition: str | dict | Callable | None = None) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -177,7 +172,7 @@ def listen(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the listener should execute. Can be:
- str: Name of a method that triggers this listener
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this listener
Returns
@@ -205,18 +200,13 @@ def listen(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -243,7 +233,7 @@ def router(condition: str | dict | Callable) -> Callable:
condition : Union[str, dict, Callable]
Specifies when the router should execute. Can be:
- str: Name of a method that triggers this router
- dict: Result from or_() or and_(), including nested conditions
- dict: Contains "type" ("AND"/"OR") and "methods" (list of triggers)
- Callable: A method reference that triggers this router
Returns
@@ -276,18 +266,13 @@ def router(condition: str | dict | Callable) -> Callable:
if isinstance(condition, str):
func.__trigger_methods__ = [condition]
func.__condition_type__ = "OR"
elif isinstance(condition, dict) and "type" in condition:
if "conditions" in condition:
func.__trigger_condition__ = condition
func.__trigger_methods__ = _extract_all_methods(condition)
func.__condition_type__ = condition["type"]
elif "methods" in condition:
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
else:
raise ValueError(
"Condition dict must contain 'conditions' or 'methods'"
)
elif (
isinstance(condition, dict)
and "type" in condition
and "methods" in condition
):
func.__trigger_methods__ = condition["methods"]
func.__condition_type__ = condition["type"]
elif callable(condition) and hasattr(condition, "__name__"):
func.__trigger_methods__ = [condition.__name__]
func.__condition_type__ = "OR"
@@ -313,15 +298,14 @@ def or_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries (nested conditions)
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "OR", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
{"type": "OR", "methods": list_of_method_names}
Raises
------
@@ -333,22 +317,18 @@ def or_(*conditions: str | dict | Callable) -> dict:
>>> @listen(or_("success", "timeout"))
>>> def handle_completion(self):
... pass
>>> @listen(or_(and_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_conditions: list[str | dict[str, Any]] = []
methods = []
for condition in conditions:
if isinstance(condition, dict):
processed_conditions.append(condition)
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
elif isinstance(condition, str):
processed_conditions.append(condition)
methods.append(condition)
elif callable(condition):
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in or_()")
return {"type": "OR", "conditions": processed_conditions}
return {"type": "OR", "methods": methods}
def and_(*conditions: str | dict | Callable) -> dict:
@@ -364,15 +344,14 @@ def and_(*conditions: str | dict | Callable) -> dict:
*conditions : Union[str, dict, Callable]
Variable number of conditions that can be:
- str: Method names
- dict: Existing condition dictionaries (nested conditions)
- dict: Existing condition dictionaries
- Callable: Method references
Returns
-------
dict
A condition dictionary with format:
{"type": "AND", "conditions": list_of_conditions}
where each condition can be a string (method name) or a nested dict
{"type": "AND", "methods": list_of_method_names}
Raises
------
@@ -384,69 +363,18 @@ def and_(*conditions: str | dict | Callable) -> dict:
>>> @listen(and_("validated", "processed"))
>>> def handle_complete_data(self):
... pass
>>> @listen(and_(or_("step1", "step2"), "step3"))
>>> def handle_nested(self):
... pass
"""
processed_conditions: list[str | dict[str, Any]] = []
methods = []
for condition in conditions:
if isinstance(condition, dict):
processed_conditions.append(condition)
if isinstance(condition, dict) and "methods" in condition:
methods.extend(condition["methods"])
elif isinstance(condition, str):
processed_conditions.append(condition)
methods.append(condition)
elif callable(condition):
processed_conditions.append(getattr(condition, "__name__", repr(condition)))
methods.append(getattr(condition, "__name__", repr(condition)))
else:
raise ValueError("Invalid condition in and_()")
return {"type": "AND", "conditions": processed_conditions}
def _normalize_condition(condition: str | dict | list) -> dict:
"""Normalize a condition to standard format with 'conditions' key.
Args:
condition: Can be a string (method name), dict (condition), or list
Returns:
Normalized dict with 'type' and 'conditions' keys
"""
if isinstance(condition, str):
return {"type": "OR", "conditions": [condition]}
if isinstance(condition, dict):
if "conditions" in condition:
return condition
if "methods" in condition:
return {"type": condition["type"], "conditions": condition["methods"]}
return condition
if isinstance(condition, list):
return {"type": "OR", "conditions": condition}
return {"type": "OR", "conditions": [condition]}
def _extract_all_methods(condition: str | dict | list) -> list[str]:
"""Extract all method names from a condition (including nested).
Args:
condition: Can be a string, dict, or list
Returns:
List of all method names in the condition tree
"""
if isinstance(condition, str):
return [condition]
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
methods = []
for sub_cond in normalized.get("conditions", []):
methods.extend(_extract_all_methods(sub_cond))
return methods
if isinstance(condition, list):
methods = []
for item in condition:
methods.extend(_extract_all_methods(item))
return methods
return []
return {"type": "AND", "methods": methods}
class FlowMeta(type):
@@ -474,10 +402,7 @@ class FlowMeta(type):
if hasattr(attr_value, "__trigger_methods__"):
methods = attr_value.__trigger_methods__
condition_type = getattr(attr_value, "__condition_type__", "OR")
if hasattr(attr_value, "__trigger_condition__"):
listeners[attr_name] = attr_value.__trigger_condition__
else:
listeners[attr_name] = (condition_type, methods)
listeners[attr_name] = (condition_type, methods)
if (
hasattr(attr_value, "__is_router__")
@@ -897,7 +822,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
self._method_outputs.clear()
self._pending_and_listeners.clear()
else:
# We're restoring from persistence, set the flag
self._is_execution_resuming = True
@@ -1162,16 +1086,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
for method_name in self._start_methods:
# Check if this start method is triggered by the current trigger
if method_name in self._listeners:
condition_data = self._listeners[method_name]
should_trigger = False
if isinstance(condition_data, tuple):
_, trigger_methods = condition_data
should_trigger = current_trigger in trigger_methods
elif isinstance(condition_data, dict):
all_methods = _extract_all_methods(condition_data)
should_trigger = current_trigger in all_methods
if should_trigger:
condition_type, trigger_methods = self._listeners[
method_name
]
if current_trigger in trigger_methods:
# Only execute if this is a cycle (method was already completed)
if method_name in self._completed_methods:
# For router-triggered start methods in cycles, temporarily clear resumption flag
@@ -1181,51 +1099,6 @@ class Flow(Generic[T], metaclass=FlowMeta):
await self._execute_start_method(method_name)
self._is_execution_resuming = was_resuming
def _evaluate_condition(
self, condition: str | dict, trigger_method: str, listener_name: str
) -> bool:
"""Recursively evaluate a condition (simple or nested).
Args:
condition: Can be a string (method name) or dict (nested condition)
trigger_method: The method that just completed
listener_name: Name of the listener being evaluated
Returns:
True if the condition is satisfied, False otherwise
"""
if isinstance(condition, str):
return condition == trigger_method
if isinstance(condition, dict):
normalized = _normalize_condition(condition)
cond_type = normalized.get("type", "OR")
sub_conditions = normalized.get("conditions", [])
if cond_type == "OR":
return any(
self._evaluate_condition(sub_cond, trigger_method, listener_name)
for sub_cond in sub_conditions
)
if cond_type == "AND":
pending_key = f"{listener_name}:{id(condition)}"
if pending_key not in self._pending_and_listeners:
all_methods = set(_extract_all_methods(condition))
self._pending_and_listeners[pending_key] = all_methods
if trigger_method in self._pending_and_listeners[pending_key]:
self._pending_and_listeners[pending_key].discard(trigger_method)
if not self._pending_and_listeners[pending_key]:
self._pending_and_listeners.pop(pending_key, None)
return True
return False
return False
def _find_triggered_methods(
self, trigger_method: str, router_only: bool
) -> list[str]:
@@ -1233,7 +1106,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
Finds all methods that should be triggered based on conditions.
This internal method evaluates both OR and AND conditions to determine
which methods should be executed next in the flow. Supports nested conditions.
which methods should be executed next in the flow.
Parameters
----------
@@ -1250,13 +1123,14 @@ class Flow(Generic[T], metaclass=FlowMeta):
Notes
-----
- Handles both OR and AND conditions, including nested combinations
- Handles both OR and AND conditions:
* OR: Triggers if any condition is met
* AND: Triggers only when all conditions are met
- Maintains state for AND conditions using _pending_and_listeners
- Separates router and normal listener evaluation
"""
triggered = []
for listener_name, condition_data in self._listeners.items():
for listener_name, (condition_type, methods) in self._listeners.items():
is_router = listener_name in self._routers
if router_only != is_router:
@@ -1265,29 +1139,23 @@ class Flow(Generic[T], metaclass=FlowMeta):
if not router_only and listener_name in self._start_methods:
continue
if isinstance(condition_data, tuple):
condition_type, methods = condition_data
if condition_type == "OR":
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(
trigger_method
)
if not self._pending_and_listeners[listener_name]:
triggered.append(listener_name)
self._pending_and_listeners.pop(listener_name, None)
elif isinstance(condition_data, dict):
if self._evaluate_condition(
condition_data, trigger_method, listener_name
):
if condition_type == "OR":
# If the trigger_method matches any in methods, run this
if trigger_method in methods:
triggered.append(listener_name)
elif condition_type == "AND":
# Initialize pending methods for this listener if not already done
if listener_name not in self._pending_and_listeners:
self._pending_and_listeners[listener_name] = set(methods)
# Remove the trigger method from pending methods
if trigger_method in self._pending_and_listeners[listener_name]:
self._pending_and_listeners[listener_name].discard(trigger_method)
if not self._pending_and_listeners[listener_name]:
# All required methods have been executed
triggered.append(listener_name)
# Reset pending methods for this listener
self._pending_and_listeners.pop(listener_name, None)
return triggered
@@ -1350,7 +1218,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
raise
def _log_flow_event(
self, message: str, color: PrinterColor | None = "yellow", level: str = "info"
self, message: str, color: str = "yellow", level: str = "info"
) -> None:
"""Centralized logging method for flow events.

View File

@@ -299,6 +299,7 @@ class LLM(BaseLLM):
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
supports_function_calling: bool | None = None,
**kwargs,
):
self.model = model
@@ -325,6 +326,7 @@ class LLM(BaseLLM):
self.additional_params = kwargs
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self._supports_function_calling_override = supports_function_calling
litellm.drop_params = True
@@ -1197,6 +1199,9 @@ class LLM(BaseLLM):
)
def supports_function_calling(self) -> bool:
if self._supports_function_calling_override is not None:
return self._supports_function_calling_override
try:
provider = self._get_custom_llm_provider()
return litellm.utils.supports_function_calling(

View File

@@ -9,6 +9,7 @@ from typing import Any, Final
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
DEFAULT_SUPPORTS_FUNCTION_CALLING: Final[bool] = True
class BaseLLM(ABC):
@@ -82,6 +83,14 @@ class BaseLLM(ABC):
RuntimeError: If the LLM request fails for other reasons.
"""
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
True if the LLM supports function calling, False otherwise.
"""
return DEFAULT_SUPPORTS_FUNCTION_CALLING
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.

View File

@@ -228,11 +228,8 @@ def build_embedder_from_dict(spec):
"""Build an embedding function instance from a dictionary specification.
Args:
spec: A dictionary with 'provider' and optionally 'config' keys.
Supports two formats:
Nested format (recommended):
{
spec: A dictionary with 'provider' and 'config' keys.
Example: {
"provider": "openai",
"config": {
"api_key": "sk-...",
@@ -240,13 +237,6 @@ def build_embedder_from_dict(spec):
}
}
Flat format (for backward compatibility):
{
"provider": "openai",
"api_key": "sk-...",
"model_name": "text-embedding-3-small"
}
Returns:
An instance of the appropriate embedding function.
@@ -276,10 +266,7 @@ def build_embedder_from_dict(spec):
except (ImportError, AttributeError, ValueError) as e:
raise ImportError(f"Failed to import provider {provider_name}: {e}") from e
if "config" in spec:
provider_config = spec["config"]
else:
provider_config = {k: v for k, v in spec.items() if k != "provider"}
provider_config = spec.get("config", {})
if provider_name == "custom" and "embedding_callable" not in provider_config:
raise ValueError("Custom provider requires 'embedding_callable' in config")

View File

@@ -13,10 +13,10 @@ class GenerativeAiProviderConfig(TypedDict, total=False):
task_type: Annotated[str, "RETRIEVAL_DOCUMENT"]
class GenerativeAiProviderSpec(TypedDict, total=False):
class GenerativeAiProviderSpec(TypedDict):
"""Google Generative AI provider specification."""
provider: Required[Literal["google-generativeai"]]
provider: Literal["google-generativeai"]
config: GenerativeAiProviderConfig

View File

@@ -1,11 +1,6 @@
"""Utility for colored console output."""
from __future__ import annotations
from typing import TYPE_CHECKING, Final, Literal, NamedTuple
if TYPE_CHECKING:
from _typeshed import SupportsWrite
from typing import Final, Literal, NamedTuple
PrinterColor = Literal[
"purple",
@@ -59,22 +54,13 @@ class Printer:
@staticmethod
def print(
content: str | list[ColoredText],
color: PrinterColor | None = None,
sep: str | None = " ",
end: str | None = "\n",
file: SupportsWrite[str] | None = None,
flush: Literal[False] = False,
content: str | list[ColoredText], color: PrinterColor | None = None
) -> None:
"""Prints content to the console with optional color formatting.
Args:
content: Either a string or a list of ColoredText objects for multicolor output.
color: Optional color for the text when content is a string. Ignored when content is a list.
sep: Separator to use between the text and color.
end: String appended after the last value.
file: A file-like object (stream); defaults to the current sys.stdout.
flush: Whether to forcibly flush the stream.
"""
if isinstance(content, str):
content = [ColoredText(content, color)]
@@ -82,9 +68,5 @@ class Printer:
"".join(
f"{_COLOR_CODES[c.color] if c.color else ''}{c.text}{RESET}"
for c in content
),
sep=sep,
end=end,
file=file,
flush=flush,
)
)

View File

@@ -242,61 +242,3 @@ class TestEmbeddingFactory:
mock_build_from_provider.assert_called_once_with(mock_provider)
assert result == mock_embedding_function
mock_import.assert_not_called()
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_nested_config(self, mock_import):
"""Test building Google Generative AI embedder with nested config format."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-gemini-key"
assert call_kwargs["model_name"] == "models/text-embedding-004"
@patch("crewai.rag.embeddings.factory.import_and_validate_definition")
def test_build_embedder_google_generativeai_flat_config(self, mock_import):
"""Test building Google Generative AI embedder with flat config format (issue #3741)."""
mock_provider_class = MagicMock()
mock_provider_instance = MagicMock()
mock_embedding_function = MagicMock()
mock_import.return_value = mock_provider_class
mock_provider_class.return_value = mock_provider_instance
mock_provider_instance.embedding_callable.return_value = mock_embedding_function
config = {
"provider": "google-generativeai",
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
}
build_embedder(config)
mock_import.assert_called_once_with(
"crewai.rag.embeddings.providers.google.generative_ai.GenerativeAiProvider"
)
mock_provider_class.assert_called_once()
call_kwargs = mock_provider_class.call_args.kwargs
assert call_kwargs["api_key"] == "test-gemini-key"
assert call_kwargs["model_name"] == "models/text-embedding-004"

View File

@@ -1,107 +0,0 @@
"""Tests for Google Generative AI embedder configuration (issue #3741)."""
from unittest.mock import MagicMock, patch
import pytest
from crewai import Agent, Crew, Task
class TestGoogleGenerativeAIEmbedder:
"""Test Google Generative AI embedder configuration formats."""
@patch("crewai.crew.Knowledge")
@patch("crewai.crew.ShortTermMemory")
@patch("crewai.crew.LongTermMemory")
@patch("crewai.crew.EntityMemory")
def test_crew_with_google_generativeai_flat_config(
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
):
"""Test that Crew accepts google-generativeai embedder with flat config format (issue #3741)."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
embedder_config = {
"provider": "google-generativeai",
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
}
crew = Crew(
agents=[agent],
tasks=[task],
embedder=embedder_config,
)
expected_normalized_config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
assert crew.embedder == expected_normalized_config
@patch("crewai.crew.Knowledge")
@patch("crewai.crew.ShortTermMemory")
@patch("crewai.crew.LongTermMemory")
@patch("crewai.crew.EntityMemory")
def test_crew_with_google_generativeai_nested_config(
self, mock_entity_memory, mock_long_term_memory, mock_short_term_memory, mock_knowledge
):
"""Test that Crew accepts google-generativeai embedder with nested config format."""
agent = Agent(
role="Test Agent",
goal="Test goal",
backstory="Test backstory",
)
task = Task(
description="Test task",
expected_output="Test output",
agent=agent,
)
embedder_config = {
"provider": "google-generativeai",
"config": {
"api_key": "test-gemini-key",
"model_name": "models/text-embedding-004",
},
}
crew = Crew(
agents=[agent],
tasks=[task],
embedder=embedder_config,
)
assert crew.embedder == embedder_config
def test_generativeai_provider_spec_validation(self):
"""Test that GenerativeAiProviderSpec validates correctly with optional config."""
from crewai.rag.embeddings.types import GenerativeAiProviderSpec
flat_spec: GenerativeAiProviderSpec = {
"provider": "google-generativeai",
}
assert flat_spec["provider"] == "google-generativeai"
nested_spec: GenerativeAiProviderSpec = {
"provider": "google-generativeai",
"config": {
"api_key": "test-key",
"model_name": "models/text-embedding-004",
},
}
assert nested_spec["provider"] == "google-generativeai"
assert nested_spec["config"]["api_key"] == "test-key"

View File

@@ -1,4 +1,4 @@
from typing import Any, Dict, List, Optional, Union
from typing import Any
import pytest
@@ -159,11 +159,11 @@ class JWTAuthLLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Record the call and return a predefined response."""
self.calls.append(
{
@@ -192,7 +192,7 @@ class JWTAuthLLM(BaseLLM):
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token") # noqa: S106
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
@@ -238,11 +238,11 @@ class TimeoutHandlingLLM(BaseLLM):
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
"""Simulate API calls with timeout handling and retry logic.
Args:
@@ -282,35 +282,34 @@ class TimeoutHandlingLLM(BaseLLM):
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on first attempt
return "First attempt response"
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Success on first attempt
return "First attempt response"
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
# Success on retry
return "Response after retry"
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
@@ -358,3 +357,25 @@ def test_timeout_handling_llm():
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
llm.call("Test message")
assert len(llm.calls) == 2 # Initial call + failed retry attempt
class MinimalCustomLLM(BaseLLM):
"""Minimal custom LLM implementation that doesn't override supports_function_calling."""
def __init__(self):
super().__init__(model="minimal-model")
def call(
self,
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
) -> str | Any:
return "Minimal response"
def test_base_llm_supports_function_calling_default():
"""Test that BaseLLM supports function calling by default."""
llm = MinimalCustomLLM()
assert llm.supports_function_calling() is True

View File

@@ -6,15 +6,15 @@ from datetime import datetime
import pytest
from pydantic import BaseModel
from crewai.flow.flow import Flow, and_, listen, or_, router, start
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.flow_events import (
FlowFinishedEvent,
FlowPlotEvent,
FlowStartedEvent,
FlowPlotEvent,
MethodExecutionFinishedEvent,
MethodExecutionStartedEvent,
)
from crewai.flow.flow import Flow, and_, listen, or_, router, start
def test_simple_sequential_flow():
@@ -679,11 +679,11 @@ def test_structured_flow_event_emission():
assert isinstance(received_events[3], MethodExecutionStartedEvent)
assert received_events[3].method_name == "send_welcome_message"
assert received_events[3].params == {}
assert received_events[3].state.sent is False
assert getattr(received_events[3].state, "sent") is False
assert isinstance(received_events[4], MethodExecutionFinishedEvent)
assert received_events[4].method_name == "send_welcome_message"
assert received_events[4].state.sent is True
assert getattr(received_events[4].state, "sent") is True
assert received_events[4].result == "Welcome, Anakin!"
assert isinstance(received_events[5], FlowFinishedEvent)
@@ -894,75 +894,3 @@ def test_flow_name():
flow = MyFlow()
assert flow.name == "MyFlow"
def test_nested_and_or_conditions():
"""Test nested conditions like or_(and_(A, B), and_(C, D)).
Reproduces bug from issue #3719 where nested conditions are flattened,
causing premature execution.
"""
execution_order = []
class NestedConditionFlow(Flow):
@start()
def method_1(self):
execution_order.append("method_1")
@listen(method_1)
def method_2(self):
execution_order.append("method_2")
@router(method_2)
def method_3(self):
execution_order.append("method_3")
# Choose b_condition path
return "b_condition"
@listen("b_condition")
def method_5(self):
execution_order.append("method_5")
@listen(method_5)
async def method_4(self):
execution_order.append("method_4")
@listen(or_("a_condition", "b_condition"))
async def method_6(self):
execution_order.append("method_6")
@listen(
or_(
and_("a_condition", method_6),
and_(method_6, method_4),
)
)
def method_7(self):
execution_order.append("method_7")
@listen(method_7)
async def method_8(self):
execution_order.append("method_8")
flow = NestedConditionFlow()
flow.kickoff()
# Verify execution happened
assert "method_1" in execution_order
assert "method_2" in execution_order
assert "method_3" in execution_order
assert "method_5" in execution_order
assert "method_4" in execution_order
assert "method_6" in execution_order
assert "method_7" in execution_order
assert "method_8" in execution_order
# Critical assertion: method_7 should only execute AFTER both method_6 AND method_4
# Since b_condition was returned, method_6 triggers on b_condition
# method_7 requires: (a_condition AND method_6) OR (method_6 AND method_4)
# The second condition (method_6 AND method_4) should be the one that triggers
assert execution_order.index("method_7") > execution_order.index("method_6")
assert execution_order.index("method_7") > execution_order.index("method_4")
# method_8 should execute after method_7
assert execution_order.index("method_8") > execution_order.index("method_7")

View File

@@ -711,3 +711,18 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
formatted = ollama_llm._format_messages_for_provider(original_messages)
assert formatted == original_messages
def test_supports_function_calling_with_override_true():
llm = LLM(model="custom-model/my-model", supports_function_calling=True)
assert llm.supports_function_calling() is True
def test_supports_function_calling_with_override_false():
llm = LLM(model="gpt-4o-mini", supports_function_calling=False)
assert llm.supports_function_calling() is False
def test_supports_function_calling_without_override():
llm = LLM(model="gpt-4o-mini")
assert llm.supports_function_calling() is True

6815
uv.lock generated

File diff suppressed because it is too large Load Diff