Compare commits

..

1 Commits

Author SHA1 Message Date
Devin AI
28147528f9 fix: add async method validation to crew decorators
This commit adds validation to detect and reject async methods in crew
decorators (@agent, @task, @crew, @llm, @tool, @callback, @cache_handler,
@before_kickoff, @after_kickoff).

Previously, decorating async methods would silently fail at runtime with
confusing errors like "'coroutine' object has no attribute 'name'".

Now, a clear TypeError is raised at decoration time with:
- The specific decorator name that doesn't support async
- The method name that was incorrectly defined as async
- Helpful suggestions for workarounds

Fixes #3988

Co-Authored-By: João <joao@crewai.com>
2025-11-28 17:51:12 +00:00
16 changed files with 293 additions and 334 deletions

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube>=15.0.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.6.1",
"crewai==1.6.0",
"lancedb>=0.5.4",
"tiktoken>=0.8.0",
"beautifulsoup4>=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.6.1"
__version__ = "1.6.0"

View File

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.6.1",
"crewai-tools==1.6.0",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.6.1"
__version__ = "1.6.0"
_telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.1"
"crewai[tools]==1.6.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.1"
"crewai[tools]==1.6.0"
]
[project.scripts]

View File

@@ -406,100 +406,46 @@ class LLM(BaseLLM):
instance.is_litellm = True
return instance
@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
"""Check if a model name matches provider-specific patterns.
This allows supporting models that aren't in the hardcoded constants list,
including "latest" versions and new models that follow provider naming conventions.
Args:
model: The model name to check
provider: The provider to check against (canonical name)
Returns:
True if the model matches the provider's naming pattern, False otherwise
"""
model_lower = model.lower()
if provider == "openai":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
)
if provider == "anthropic" or provider == "claude":
return any(
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
)
if provider == "gemini" or provider == "google":
return any(
model_lower.startswith(prefix)
for prefix in ["gemini-", "gemma-", "learnlm-"]
)
if provider == "bedrock":
return "." in model_lower
if provider == "azure":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
return False
@classmethod
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
"""Validate if a model name exists in the provider's constants or matches provider patterns.
This method first checks the hardcoded constants list for known models.
If not found, it falls back to pattern matching to support new models,
"latest" versions, and models that follow provider naming conventions.
"""Validate if a model name exists in the provider's constants.
Args:
model: The model name to validate
provider: The provider to check against (canonical name)
Returns:
True if the model exists in constants or matches provider patterns, False otherwise
True if the model exists in the provider's constants, False otherwise
"""
if provider == "openai" and model in OPENAI_MODELS:
return True
if provider == "openai":
return model in OPENAI_MODELS
if (
provider == "anthropic" or provider == "claude"
) and model in ANTHROPIC_MODELS:
return True
if provider == "anthropic" or provider == "claude":
return model in ANTHROPIC_MODELS
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
return True
if provider == "gemini":
return model in GEMINI_MODELS
if provider == "bedrock" and model in BEDROCK_MODELS:
return True
if provider == "bedrock":
return model in BEDROCK_MODELS
if provider == "azure":
# azure does not provide a list of available models, determine a better way to handle this
return True
# Fallback to pattern matching for models not in constants
return cls._matches_provider_pattern(model, provider)
return False
@classmethod
def _infer_provider_from_model(cls, model: str) -> str:
"""Infer the provider from the model name.
This method first checks the hardcoded constants list for known models.
If not found, it uses pattern matching to infer the provider from model name patterns.
This allows supporting new models and "latest" versions without hardcoding.
Args:
model: The model name without provider prefix
Returns:
The inferred provider name, defaults to "openai"
"""
if model in OPENAI_MODELS:
return "openai"
@@ -1753,14 +1699,12 @@ class LLM(BaseLLM):
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=(
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
),
response_format=(
copy.deepcopy(self.response_format, memo)
if self.response_format
else None
),
logit_bias=copy.deepcopy(self.logit_bias, memo)
if self.logit_bias
else None,
response_format=copy.deepcopy(self.response_format, memo)
if self.response_format
else None,
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,

View File

@@ -182,8 +182,6 @@ OPENAI_MODELS: list[OpenAIModels] = [
AnthropicModels: TypeAlias = Literal[
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -210,8 +208,6 @@ AnthropicModels: TypeAlias = Literal[
"claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -456,7 +452,6 @@ BedrockModels: TypeAlias = Literal[
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
@@ -529,7 +524,6 @@ BEDROCK_MODELS: list[BedrockModels] = [
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",

View File

@@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
@@ -27,7 +26,6 @@ try:
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
JsonSchemaFormat,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
@@ -280,16 +278,13 @@ class AzureCompletion(BaseLLM):
}
if response_model and self.is_openai_model:
model_description = generate_model_description(response_model)
json_schema_info = model_description["json_schema"]
json_schema_name = json_schema_info["name"]
params["response_format"] = JsonSchemaFormat(
name=json_schema_name,
schema=json_schema_info["schema"],
description=f"Schema for {json_schema_name}",
strict=json_schema_info["strict"],
)
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"schema": response_model.model_json_schema(),
},
}
# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
@@ -316,8 +311,8 @@ class AzureCompletion(BaseLLM):
params["tool_choice"] = "auto"
additional_params = self.additional_params
additional_drop_params = additional_params.get("additional_drop_params")
drop_params = additional_params.get("drop_params")
additional_drop_params = additional_params.get('additional_drop_params')
drop_params = additional_params.get('drop_params')
if drop_params and isinstance(additional_drop_params, list):
for drop_param in additional_drop_params:

View File

@@ -5,15 +5,9 @@ from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
from crewai.project.utils import memoize
if TYPE_CHECKING:
from crewai import Agent, Crew, Task
from crewai.project.wrappers import (
AfterKickoffMethod,
AgentMethod,
@@ -30,6 +24,31 @@ from crewai.project.wrappers import (
)
if TYPE_CHECKING:
from crewai import Agent, Crew, Task
def _check_async_method(meth: Callable[..., Any], decorator_name: str) -> None:
"""Check if a method is async and raise an error if so.
Args:
meth: The method to check.
decorator_name: The name of the decorator for the error message.
Raises:
TypeError: If the method is an async function.
"""
if asyncio.iscoroutinefunction(meth):
raise TypeError(
f"The @{decorator_name} decorator does not support async methods. "
f"Method '{meth.__name__}' is defined as async. "
f"Please use a synchronous method instead. "
f"If you need to perform async operations, consider: "
f"1) Creating tools/resources synchronously before crew execution, or "
f"2) Using asyncio.run() within a sync method for isolated async calls."
)
P = ParamSpec("P")
P2 = ParamSpec("P2")
R = TypeVar("R")
@@ -46,7 +65,11 @@ def before_kickoff(meth: Callable[P, R]) -> BeforeKickoffMethod[P, R]:
Returns:
A wrapped method marked for before kickoff execution.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "before_kickoff")
return BeforeKickoffMethod(meth)
@@ -58,7 +81,11 @@ def after_kickoff(meth: Callable[P, R]) -> AfterKickoffMethod[P, R]:
Returns:
A wrapped method marked for after kickoff execution.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "after_kickoff")
return AfterKickoffMethod(meth)
@@ -70,7 +97,11 @@ def task(meth: Callable[P, TaskResultT]) -> TaskMethod[P, TaskResultT]:
Returns:
A wrapped method marked as a task with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "task")
return TaskMethod(memoize(meth))
@@ -82,7 +113,11 @@ def agent(meth: Callable[P, R]) -> AgentMethod[P, R]:
Returns:
A wrapped method marked as an agent with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "agent")
return AgentMethod(memoize(meth))
@@ -94,7 +129,11 @@ def llm(meth: Callable[P, R]) -> LLMMethod[P, R]:
Returns:
A wrapped method marked as an LLM provider with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "llm")
return LLMMethod(memoize(meth))
@@ -130,7 +169,11 @@ def tool(meth: Callable[P, R]) -> ToolMethod[P, R]:
Returns:
A wrapped method marked as a tool with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "tool")
return ToolMethod(memoize(meth))
@@ -142,7 +185,11 @@ def callback(meth: Callable[P, R]) -> CallbackMethod[P, R]:
Returns:
A wrapped method marked as a callback with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "callback")
return CallbackMethod(memoize(meth))
@@ -154,27 +201,14 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
Returns:
A wrapped method marked as a cache handler with memoization.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "cache_handler")
return CacheHandlerMethod(memoize(meth))
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -193,7 +227,11 @@ def crew(
Returns:
A wrapped method that instantiates tasks and agents before execution.
Raises:
TypeError: If the method is an async function.
"""
_check_async_method(meth, "crew")
@wraps(meth)
def wrapper(self: CrewInstance, *args: Any, **kwargs: Any) -> Crew:
@@ -217,7 +255,7 @@ def crew(
# Instantiate tasks in order
for _, task_method in tasks:
task_instance = _call_method(task_method, self)
task_instance = task_method(self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -226,7 +264,7 @@ def crew(
# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = _call_method(agent_method, self)
agent_instance = agent_method(self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -234,7 +272,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
crew_instance = meth(self, *args, **kwargs)
def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

View File

@@ -1,8 +1,7 @@
"""Utility functions for the crewai project module."""
from collections.abc import Callable, Coroutine
from collections.abc import Callable
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast
from pydantic import BaseModel
@@ -38,8 +37,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.
Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Args:
meth: The method to memoize.
@@ -47,16 +46,18 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""
@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.
Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.
Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -72,27 +73,3 @@ def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
return result
return cast(Callable[P, R], wrapper)
def _memoize_async(
    meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
    """Memoize an async method.

    Results are cached per argument combination; arguments are first passed
    through _make_hashable so unhashable values (e.g. Pydantic models) can be
    folded into the cache key.
    """

    @wraps(meth)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        # Build a deterministic cache key from the hashable argument forms.
        key = str(
            (
                tuple(_make_hashable(a) for a in args),
                tuple(
                    sorted((name, _make_hashable(value)) for name, value in kwargs.items())
                ),
            )
        )
        hit: R | None = cache.read(tool=meth.__name__, input=key)
        if hit is not None:
            return hit
        fresh = await meth(*args, **kwargs)
        cache.add(tool=meth.__name__, input=key, output=fresh)
        return fresh

    return wrapper

View File

@@ -2,10 +2,8 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -134,22 +132,6 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]
def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
@@ -180,12 +162,7 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
inner = partial(self._meth, obj)
def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
return result
bound = partial(self._meth, obj)
for attr in (
"is_agent",
"is_llm",
@@ -197,8 +174,8 @@ class DecoratedMethod(Generic[P, R]):
"is_crew",
):
if hasattr(self, attr):
setattr(_bound, attr, getattr(self, attr))
return _bound
setattr(bound, attr, getattr(self, attr))
return bound
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -259,7 +236,6 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)
@@ -316,9 +292,7 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)
return self.ensure_task_name(self._meth(*args, **kwargs))
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.

View File

@@ -0,0 +1,169 @@
"""Tests for async method validation in crew decorators."""
import pytest
from crewai.project import (
after_kickoff,
agent,
before_kickoff,
cache_handler,
callback,
crew,
llm,
task,
tool,
)
class TestAsyncDecoratorValidation:
    """Test that decorators properly reject async methods with clear error messages."""

    def test_agent_decorator_rejects_async_method(self):
        """Test that @agent decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @agent
            async def async_agent(self):
                return None

        message = str(exc_info.value)
        assert "@agent decorator does not support async methods" in message
        assert "async_agent" in message
        assert "synchronous method" in message

    def test_task_decorator_rejects_async_method(self):
        """Test that @task decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @task
            async def async_task(self):
                return None

        message = str(exc_info.value)
        assert "@task decorator does not support async methods" in message
        assert "async_task" in message
        assert "synchronous method" in message

    def test_crew_decorator_rejects_async_method(self):
        """Test that @crew decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @crew
            async def async_crew(self):
                return None

        message = str(exc_info.value)
        assert "@crew decorator does not support async methods" in message
        assert "async_crew" in message
        assert "synchronous method" in message

    def test_llm_decorator_rejects_async_method(self):
        """Test that @llm decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @llm
            async def async_llm(self):
                return None

        message = str(exc_info.value)
        assert "@llm decorator does not support async methods" in message
        assert "async_llm" in message
        assert "synchronous method" in message

    def test_tool_decorator_rejects_async_method(self):
        """Test that @tool decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @tool
            async def async_tool(self):
                return None

        message = str(exc_info.value)
        assert "@tool decorator does not support async methods" in message
        assert "async_tool" in message
        assert "synchronous method" in message

    def test_callback_decorator_rejects_async_method(self):
        """Test that @callback decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @callback
            async def async_callback(self):
                return None

        message = str(exc_info.value)
        assert "@callback decorator does not support async methods" in message
        assert "async_callback" in message
        assert "synchronous method" in message

    def test_cache_handler_decorator_rejects_async_method(self):
        """Test that @cache_handler decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @cache_handler
            async def async_cache_handler(self):
                return None

        message = str(exc_info.value)
        assert "@cache_handler decorator does not support async methods" in message
        assert "async_cache_handler" in message
        assert "synchronous method" in message

    def test_before_kickoff_decorator_rejects_async_method(self):
        """Test that @before_kickoff decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @before_kickoff
            async def async_before_kickoff(self, inputs):
                return inputs

        message = str(exc_info.value)
        assert "@before_kickoff decorator does not support async methods" in message
        assert "async_before_kickoff" in message
        assert "synchronous method" in message

    def test_after_kickoff_decorator_rejects_async_method(self):
        """Test that @after_kickoff decorator raises TypeError for async methods."""
        with pytest.raises(TypeError) as exc_info:

            @after_kickoff
            async def async_after_kickoff(self, outputs):
                return outputs

        message = str(exc_info.value)
        assert "@after_kickoff decorator does not support async methods" in message
        assert "async_after_kickoff" in message
        assert "synchronous method" in message

    def test_sync_methods_still_work(self):
        """Test that synchronous methods are still properly decorated."""
        from crewai import Agent, Task

        @agent
        def sync_agent(self):
            return Agent(
                role="Test Agent", goal="Test Goal", backstory="Test Backstory"
            )

        @task
        def sync_task(self):
            return Task(description="Test Description", expected_output="Test Output")

        class HostCrew:
            pass

        host = HostCrew()
        built_agent = sync_agent(host)
        built_task = sync_task(host)
        assert built_agent.role == "Test Agent"
        assert built_task.description == "Test Description"

    def test_error_message_includes_workaround_suggestions(self):
        """Test that error messages include helpful workaround suggestions."""
        with pytest.raises(TypeError) as exc_info:

            @agent
            async def async_agent_with_tools(self):
                return None

        message = str(exc_info.value)
        assert "Creating tools/resources synchronously" in message
        assert "asyncio.run()" in message

View File

@@ -243,11 +243,7 @@ def test_validate_call_params_not_supported():
# Patch supports_response_schema to simulate an unsupported model.
with patch("crewai.llm.supports_response_schema", return_value=False):
llm = LLM(
model="gemini/gemini-1.5-pro",
response_format=DummyResponse,
is_litellm=True,
)
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
with pytest.raises(ValueError) as excinfo:
llm._validate_call_params()
assert "does not support response_format" in str(excinfo.value)
@@ -706,16 +702,13 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
assert formatted == original_messages
def test_native_provider_raises_error_when_supported_but_fails():
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
# Mock that provider exists but throws an error when instantiated
mock_provider = MagicMock()
mock_provider.side_effect = ValueError(
"Native provider initialization failed"
)
mock_provider.side_effect = ValueError("Native provider initialization failed")
mock_get_native.return_value = mock_provider
with pytest.raises(ImportError) as excinfo:
@@ -758,16 +751,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
def test_prefixed_models_with_invalid_constants_use_litellm():
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
# Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
assert llm.is_litellm is True
assert llm.model == "openai/gemini-2.5-flash"
# Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
# Test openai/ prefix with unknown future model → LiteLLM
llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
assert llm2.is_litellm is True
assert llm2.model == "openai/custom-finetune-model"
assert llm2.model == "openai/gpt-future-6"
# Test anthropic/ prefix with non-Anthropic model → LiteLLM
llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
@@ -775,21 +768,6 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
assert llm3.model == "anthropic/gpt-4o"
def test_prefixed_models_with_valid_patterns_use_native_sdk():
"""Test that models matching provider patterns use native SDK even if not in constants."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
llm = LLM(model="openai/gpt-future-6", is_litellm=False)
assert llm.is_litellm is False
assert llm.provider == "openai"
assert llm.model == "gpt-future-6"
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
assert llm2.is_litellm is False
assert llm2.provider == "anthropic"
assert llm2.model == "claude-future-5"
def test_prefixed_models_with_non_native_providers_use_litellm():
"""Test that models with non-native provider prefixes always use LiteLLM."""
# Test groq/ prefix (not a native provider) → LiteLLM
@@ -843,36 +821,19 @@ def test_validate_model_in_constants():
"""Test the _validate_model_in_constants method."""
# OpenAI models
assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
assert LLM._validate_model_in_constants("o1-latest", "openai") is True
assert LLM._validate_model_in_constants("unknown-model", "openai") is False
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
# Anthropic models
assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
assert (
LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
)
assert LLM._validate_model_in_constants("unknown-model", "claude") is False
assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
# Gemini models
assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
assert LLM._validate_model_in_constants("unknown-model", "gemini") is False
assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
# Azure models
assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True
# Bedrock models
assert (
LLM._validate_model_in_constants(
"anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
)
is True
)
assert (
LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
is True
)
assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True

View File

@@ -272,99 +272,6 @@ def another_simple_tool():
return "Hi!"
class TestAsyncDecoratorSupport:
    """Tests for async method support in @agent, @task decorators."""

    def test_async_agent_memoization(self):
        """Async agent methods should be properly memoized."""

        class AsyncAgentCrew:
            call_count = 0

            @agent
            async def async_agent(self):
                AsyncAgentCrew.call_count += 1
                return Agent(
                    role="Async Agent", goal="Async Goal", backstory="Async Backstory"
                )

        instance = AsyncAgentCrew()
        first = instance.async_agent()
        second = instance.async_agent()
        assert first is second, "Async agent memoization failed"
        assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"

    def test_async_task_memoization(self):
        """Async task methods should be properly memoized."""

        class AsyncTaskCrew:
            call_count = 0

            @task
            async def async_task(self):
                AsyncTaskCrew.call_count += 1
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        instance = AsyncTaskCrew()
        first = instance.async_task()
        second = instance.async_task()
        assert first is second, "Async task memoization failed"
        assert AsyncTaskCrew.call_count == 1, "Async task called more than once"

    def test_async_task_name_inference(self):
        """Async task should have name inferred from method name."""

        class AsyncTaskNameCrew:
            @task
            async def my_async_task(self):
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        instance = AsyncTaskNameCrew()
        built = instance.my_async_task()
        assert built.name == "my_async_task", (
            "Async task name not inferred correctly"
        )

    def test_async_agent_returns_agent_not_coroutine(self):
        """Async agent decorator should return Agent, not coroutine."""

        class AsyncAgentTypeCrew:
            @agent
            async def typed_async_agent(self):
                return Agent(
                    role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
                )

        instance = AsyncAgentTypeCrew()
        result = instance.typed_async_agent()
        assert isinstance(result, Agent), (
            f"Expected Agent, got {type(result).__name__}"
        )

    def test_async_task_returns_task_not_coroutine(self):
        """Async task decorator should return Task, not coroutine."""

        class AsyncTaskTypeCrew:
            @task
            async def typed_async_task(self):
                return Task(
                    description="Typed Description", expected_output="Typed Output"
                )

        instance = AsyncTaskTypeCrew()
        result = instance.typed_async_task()
        assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
def test_internal_crew_with_mcp():
from crewai_tools.adapters.tool_collection import ToolCollection

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.6.1"
__version__ = "1.6.0"