Compare commits

..

4 Commits

Author SHA1 Message Date
Lorenze Jay
bc4e6a3127 feat: bump versions to 1.6.1 (#3993)
Some checks failed
CodeQL Advanced / Analyze (actions) (push) Has been cancelled
CodeQL Advanced / Analyze (python) (push) Has been cancelled
Notify Downstream / notify-downstream (push) Has been cancelled
Mark stale issues and pull requests / stale (push) Has been cancelled
* feat: bump versions to 1.6.1

* chore: update crewAI dependency version to 1.6.1 in project templates
2025-11-28 17:57:15 -08:00
Vidit Ostwal
37526c693b Fixing ChatCompletionsClinet call (#3910)
* Fixing ChatCompletionsClinet call

* Moving from json-object -> JsonSchemaFormat

* Regex handling

* Adding additionalProperties explicitly

* fix: ensure additionalProperties is recursive

---------

Co-authored-by: Greyson LaLonde <greyson.r.lalonde@gmail.com>
Co-authored-by: Lorenze Jay <63378463+lorenzejay@users.noreply.github.com>
2025-11-28 17:33:53 -08:00
Greyson LaLonde
c59173a762 fix: ensure async methods are executable for annotations 2025-11-28 19:54:40 -05:00
Lorenze Jay
4d8eec96e8 refactor: enhance model validation and provider inference in LLM class (#3976)
* refactor: enhance model validation and provider inference in LLM class

- Updated the model validation logic to support pattern matching for new models and "latest" versions, improving flexibility for various providers.
- Refactored the `_validate_model_in_constants` method to first check hardcoded constants and then fall back to pattern matching.
- Introduced `_matches_provider_pattern` to streamline provider-specific model checks.
- Enhanced the `_infer_provider_from_model` method to utilize pattern matching for better provider inference.

This refactor aims to improve the extensibility of the LLM class, allowing it to accommodate new models without requiring constant updates to the hardcoded lists.

* feat: add new Anthropic model versions to constants

- Introduced "claude-opus-4-5-20251101" and "claude-opus-4-5" to the AnthropicModels and ANTHROPIC_MODELS lists for enhanced model support.
- Added "anthropic.claude-opus-4-5-20251101-v1:0" to BedrockModels and BEDROCK_MODELS to ensure compatibility with the latest model offerings.
- Updated test cases to ensure proper environment variable handling for model validation, improving robustness in testing scenarios.

* dont infer this way - dropped
2025-11-28 13:54:40 -08:00
21 changed files with 332 additions and 476 deletions

View File

@@ -12,7 +12,7 @@ dependencies = [
"pytube>=15.0.0",
"requests>=2.32.5",
"docker>=7.1.0",
"crewai==1.6.0",
"crewai==1.6.1",
"lancedb>=0.5.4",
"tiktoken>=0.8.0",
"beautifulsoup4>=4.13.4",

View File

@@ -291,4 +291,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.6.0"
__version__ = "1.6.1"

View File

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.6.0",
"crewai-tools==1.6.1",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.6.0"
__version__ = "1.6.1"
_telemetry_submitted = False

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.0"
"crewai[tools]==1.6.1"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.0"
"crewai[tools]==1.6.1"
]
[project.scripts]

View File

@@ -406,46 +406,100 @@ class LLM(BaseLLM):
instance.is_litellm = True
return instance
@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
"""Check if a model name matches provider-specific patterns.
This allows supporting models that aren't in the hardcoded constants list,
including "latest" versions and new models that follow provider naming conventions.
Args:
model: The model name to check
provider: The provider to check against (canonical name)
Returns:
True if the model matches the provider's naming pattern, False otherwise
"""
model_lower = model.lower()
if provider == "openai":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
)
if provider == "anthropic" or provider == "claude":
return any(
model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
)
if provider == "gemini" or provider == "google":
return any(
model_lower.startswith(prefix)
for prefix in ["gemini-", "gemma-", "learnlm-"]
)
if provider == "bedrock":
return "." in model_lower
if provider == "azure":
return any(
model_lower.startswith(prefix)
for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
)
return False
@classmethod
def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
"""Validate if a model name exists in the provider's constants.
"""Validate if a model name exists in the provider's constants or matches provider patterns.
This method first checks the hardcoded constants list for known models.
If not found, it falls back to pattern matching to support new models,
"latest" versions, and models that follow provider naming conventions.
Args:
model: The model name to validate
provider: The provider to check against (canonical name)
Returns:
True if the model exists in the provider's constants, False otherwise
True if the model exists in constants or matches provider patterns, False otherwise
"""
if provider == "openai":
return model in OPENAI_MODELS
if provider == "openai" and model in OPENAI_MODELS:
return True
if provider == "anthropic" or provider == "claude":
return model in ANTHROPIC_MODELS
if (
provider == "anthropic" or provider == "claude"
) and model in ANTHROPIC_MODELS:
return True
if provider == "gemini":
return model in GEMINI_MODELS
if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
return True
if provider == "bedrock":
return model in BEDROCK_MODELS
if provider == "bedrock" and model in BEDROCK_MODELS:
return True
if provider == "azure":
# azure does not provide a list of available models, determine a better way to handle this
return True
return False
# Fallback to pattern matching for models not in constants
return cls._matches_provider_pattern(model, provider)
@classmethod
def _infer_provider_from_model(cls, model: str) -> str:
"""Infer the provider from the model name.
This method first checks the hardcoded constants list for known models.
If not found, it uses pattern matching to infer the provider from model name patterns.
This allows supporting new models and "latest" versions without hardcoding.
Args:
model: The model name without provider prefix
Returns:
The inferred provider name, defaults to "openai"
"""
if model in OPENAI_MODELS:
return "openai"
@@ -1699,12 +1753,14 @@ class LLM(BaseLLM):
max_tokens=self.max_tokens,
presence_penalty=self.presence_penalty,
frequency_penalty=self.frequency_penalty,
logit_bias=copy.deepcopy(self.logit_bias, memo)
if self.logit_bias
else None,
response_format=copy.deepcopy(self.response_format, memo)
if self.response_format
else None,
logit_bias=(
copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
),
response_format=(
copy.deepcopy(self.response_format, memo)
if self.response_format
else None
),
seed=self.seed,
logprobs=self.logprobs,
top_logprobs=self.top_logprobs,

View File

@@ -316,33 +316,11 @@ class BaseLLM(ABC):
from_task: Task | None = None,
from_agent: Agent | None = None,
tool_call: dict[str, Any] | None = None,
call_type: LLMCallType | None = None,
) -> None:
"""Emit stream chunk event.
Args:
chunk: The text content of the chunk
from_task: Optional task that initiated the call
from_agent: Optional agent that initiated the call
tool_call: Optional tool call information as a dict with keys:
- id: Tool call ID
- function: Dict with 'name' and 'arguments'
- type: Tool call type (e.g., 'function')
- index: Index of the tool call
call_type: Optional call type. If not provided, it will be inferred
from the presence of tool_call (TOOL_CALL if tool_call is present,
LLM_CALL otherwise)
"""
"""Emit stream chunk event."""
if not hasattr(crewai_event_bus, "emit"):
raise ValueError("crewai_event_bus does not have an emit method") from None
# Infer call_type from tool_call presence if not explicitly provided
effective_call_type = call_type
if effective_call_type is None:
effective_call_type = (
LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
)
crewai_event_bus.emit(
self,
event=LLMStreamChunkEvent(
@@ -350,7 +328,6 @@ class BaseLLM(ABC):
tool_call=tool_call,
from_task=from_task,
from_agent=from_agent,
call_type=effective_call_type,
),
)

View File

@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [
AnthropicModels: TypeAlias = Literal[
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
"claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
"claude-opus-4-5-20251101",
"claude-opus-4-5",
"claude-3-7-sonnet-latest",
"claude-3-7-sonnet-20250219",
"claude-3-5-haiku-latest",
@@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
"anthropic.claude-3-sonnet-20240229-v1:0:28k",
"anthropic.claude-haiku-4-5-20251001-v1:0",
"anthropic.claude-instant-v1:2:100k",
"anthropic.claude-opus-4-5-20251101-v1:0",
"anthropic.claude-opus-4-1-20250805-v1:0",
"anthropic.claude-opus-4-20250514-v1:0",
"anthropic.claude-sonnet-4-20250514-v1:0",

View File

@@ -450,14 +450,9 @@ class AnthropicCompletion(BaseLLM):
# (the SDK sets it internally)
stream_params = {k: v for k, v in params.items() if k != "stream"}
# Track tool use blocks during streaming
current_tool_use: dict[str, Any] = {}
tool_use_index = 0
# Make streaming API call
with self.client.messages.stream(**stream_params) as stream:
for event in stream:
# Handle text content
if hasattr(event, "delta") and hasattr(event.delta, "text"):
text_delta = event.delta.text
full_response += text_delta
@@ -467,55 +462,6 @@ class AnthropicCompletion(BaseLLM):
from_agent=from_agent,
)
# Handle tool use start (content_block_start event with tool_use type)
if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
if event.content_block.type == "tool_use":
current_tool_use = {
"id": getattr(event.content_block, "id", None),
"name": getattr(event.content_block, "name", ""),
"input": "",
"index": tool_use_index,
}
tool_use_index += 1
# Emit tool call start event
tool_call_event_data = {
"id": current_tool_use["id"],
"function": {
"name": current_tool_use["name"],
"arguments": "",
},
"type": "function",
"index": current_tool_use["index"],
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle tool use input delta (input_json events)
if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
partial_json = event.delta.partial_json
if current_tool_use and partial_json:
current_tool_use["input"] += partial_json
# Emit tool call delta event
tool_call_event_data = {
"id": current_tool_use["id"],
"function": {
"name": current_tool_use["name"],
"arguments": partial_json,
},
"type": "function",
"index": current_tool_use["index"],
}
self._emit_stream_chunk_event(
chunk=partial_json,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
final_message: Message = stream.get_final_message()
usage = self._extract_anthropic_token_usage(final_message)

View File

@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
)
@@ -26,6 +27,7 @@ try:
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
JsonSchemaFormat,
StreamingChatCompletionsUpdate,
)
from azure.core.credentials import (
@@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM):
}
if response_model and self.is_openai_model:
params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": response_model.__name__,
"schema": response_model.model_json_schema(),
},
}
model_description = generate_model_description(response_model)
json_schema_info = model_description["json_schema"]
json_schema_name = json_schema_info["name"]
params["response_format"] = JsonSchemaFormat(
name=json_schema_name,
schema=json_schema_info["schema"],
description=f"Schema for {json_schema_name}",
strict=json_schema_info["strict"],
)
# Only include model parameter for non-Azure OpenAI endpoints
# Azure OpenAI endpoints have the deployment name in the URL
@@ -311,8 +316,8 @@ class AzureCompletion(BaseLLM):
params["tool_choice"] = "auto"
additional_params = self.additional_params
additional_drop_params = additional_params.get('additional_drop_params')
drop_params = additional_params.get('drop_params')
additional_drop_params = additional_params.get("additional_drop_params")
drop_params = additional_params.get("drop_params")
if drop_params and isinstance(additional_drop_params, list):
for drop_param in additional_drop_params:
@@ -503,10 +508,8 @@ class AzureCompletion(BaseLLM):
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"id": call_id,
"name": "",
"arguments": "",
"index": getattr(tool_call, "index", 0) or 0,
}
if tool_call.function and tool_call.function.name:
@@ -516,23 +519,6 @@ class AzureCompletion(BaseLLM):
tool_call.function.arguments
)
# Emit tool call streaming event
tool_call_event_data = {
"id": tool_calls[call_id]["id"],
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
},
"type": "function",
"index": tool_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle completed tool calls
if tool_calls and available_functions:
for call_data in tool_calls.values():

View File

@@ -567,31 +567,12 @@ class BedrockCompletion(BaseLLM):
elif "contentBlockStart" in event:
start = event["contentBlockStart"].get("start", {})
block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
if "toolUse" in start:
current_tool_use = start["toolUse"]
current_tool_use["_block_index"] = block_index
current_tool_use["_accumulated_input"] = ""
tool_use_id = current_tool_use.get("toolUseId")
logging.debug(
f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
)
# Emit tool call start event
tool_call_event_data = {
"id": tool_use_id,
"function": {
"name": current_tool_use.get("name", ""),
"arguments": "",
},
"type": "function",
"index": block_index,
}
self._emit_stream_chunk_event(
chunk="",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
elif "contentBlockDelta" in event:
delta = event["contentBlockDelta"]["delta"]
@@ -608,23 +589,6 @@ class BedrockCompletion(BaseLLM):
tool_input = delta["toolUse"].get("input", "")
if tool_input:
logging.debug(f"Tool input delta: {tool_input}")
current_tool_use["_accumulated_input"] += tool_input
# Emit tool call delta event
tool_call_event_data = {
"id": current_tool_use.get("toolUseId"),
"function": {
"name": current_tool_use.get("name", ""),
"arguments": tool_input,
},
"type": "function",
"index": current_tool_use.get("_block_index", 0),
}
self._emit_stream_chunk_event(
chunk=tool_input,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Content block stop - end of a content block
elif "contentBlockStop" in event:

View File

@@ -1,4 +1,3 @@
import json
import logging
import os
import re
@@ -497,7 +496,7 @@ class GeminiCompletion(BaseLLM):
if hasattr(chunk, "candidates") and chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part_index, part in enumerate(candidate.content.parts):
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
@@ -506,27 +505,8 @@ class GeminiCompletion(BaseLLM):
"args": dict(part.function_call.args)
if part.function_call.args
else {},
"index": part_index,
}
# Emit tool call streaming event
args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
tool_call_event_data = {
"id": call_id,
"function": {
"name": function_calls[call_id]["name"],
"arguments": args_str,
},
"type": "function",
"index": function_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=args_str,
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
# Handle completed function calls
if function_calls and available_functions:
for call_data in function_calls.values():

View File

@@ -510,10 +510,8 @@ class OpenAICompletion(BaseLLM):
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"id": call_id,
"name": "",
"arguments": "",
"index": tool_call.index if tool_call.index is not None else 0,
}
if tool_call.function and tool_call.function.name:
@@ -521,23 +519,6 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
# Emit tool call streaming event
tool_call_event_data = {
"id": tool_calls[call_id]["id"],
"function": {
"name": tool_calls[call_id]["name"],
"arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
},
"type": "function",
"index": tool_calls[call_id]["index"],
}
self._emit_stream_chunk_event(
chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
from_task=from_task,
from_agent=from_agent,
tool_call=tool_call_event_data,
)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload
from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
return CacheHandlerMethod(memoize(meth))
def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
"""Call a method, awaiting it if async and running in an event loop."""
result = method(*args, **kwargs)
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
@overload
def crew(
meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(
# Instantiate tasks in order
for _, task_method in tasks:
task_instance = task_method(self)
task_instance = _call_method(task_method, self)
instantiated_tasks.append(task_instance)
agent_instance = getattr(task_instance, "agent", None)
if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(
# Instantiate agents not included by tasks
for _, agent_method in agents:
agent_instance = agent_method(self)
agent_instance = _call_method(agent_method, self)
if agent_instance.role not in agent_roles:
instantiated_agents.append(agent_instance)
agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
self.agents = instantiated_agents
self.tasks = instantiated_tasks
crew_instance = meth(self, *args, **kwargs)
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
def callback_wrapper(
hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

View File

@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""
from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast
from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a method by caching its results based on arguments.
Handles Pydantic BaseModel instances by converting them to JSON strings
before hashing for cache lookup.
Handles both sync and async methods. Pydantic BaseModel instances are
converted to JSON strings before hashing for cache lookup.
Args:
meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
Returns:
A memoized version of the method that caches results.
"""
if inspect.iscoroutinefunction(meth):
return cast(Callable[P, R], _memoize_async(meth))
return _memoize_sync(meth)
def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
"""Memoize a synchronous method."""
@wraps(meth)
def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
"""Wrapper that converts arguments to hashable form before caching.
Args:
*args: Positional arguments to the memoized method.
**kwargs: Keyword arguments to the memoized method.
Returns:
The result of the memoized method call.
"""
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
return result
return cast(Callable[P, R], wrapper)
def _memoize_async(
meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
"""Memoize an async method."""
@wraps(meth)
async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
hashable_args = tuple(_make_hashable(arg) for arg in args)
hashable_kwargs = tuple(
sorted((k, _make_hashable(v)) for k, v in kwargs.items())
)
cache_key = str((hashable_args, hashable_kwargs))
cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
if cached_result is not None:
return cached_result
result = await meth(*args, **kwargs)
cache.add(tool=meth.__name__, input=cache_key, output=result)
return result
return wrapper

View File

@@ -2,8 +2,10 @@
from __future__ import annotations
import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
crew: Callable[..., Crew]
def _resolve_result(result: Any) -> Any:
"""Resolve a potentially async result to its value."""
if inspect.iscoroutine(result):
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = None
if loop and loop.is_running():
import concurrent.futures
with concurrent.futures.ThreadPoolExecutor() as pool:
return pool.submit(asyncio.run, result).result()
return asyncio.run(result)
return result
class DecoratedMethod(Generic[P, R]):
"""Base wrapper for methods with decorator metadata.
@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
"""
if obj is None:
return self
bound = partial(self._meth, obj)
inner = partial(self._meth, obj)
def _bound(*args: Any, **kwargs: Any) -> R:
result: R = _resolve_result(inner(*args, **kwargs)) # type: ignore[call-arg]
return result
for attr in (
"is_agent",
"is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
"is_crew",
):
if hasattr(self, attr):
setattr(bound, attr, getattr(self, attr))
return bound
setattr(_bound, attr, getattr(self, attr))
return _bound
def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
The task result with name ensured.
"""
result = self._task_method.unwrap()(self._obj, *args, **kwargs)
result = _resolve_result(result)
return self._task_method.ensure_task_name(result)
@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
Returns:
The task instance with name set if not already provided.
"""
return self.ensure_task_name(self._meth(*args, **kwargs))
result = self._meth(*args, **kwargs)
result = _resolve_result(result)
return self.ensure_task_name(result)
def unwrap(self) -> Callable[P, TaskResultT]:
"""Get the original unwrapped method.

View File

@@ -243,7 +243,11 @@ def test_validate_call_params_not_supported():
# Patch supports_response_schema to simulate an unsupported model.
with patch("crewai.llm.supports_response_schema", return_value=False):
llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
llm = LLM(
model="gemini/gemini-1.5-pro",
response_format=DummyResponse,
is_litellm=True,
)
with pytest.raises(ValueError) as excinfo:
llm._validate_call_params()
assert "does not support response_format" in str(excinfo.value)
@@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):
assert formatted == original_messages
def test_native_provider_raises_error_when_supported_but_fails():
"""Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
# Mock that provider exists but throws an error when instantiated
mock_provider = MagicMock()
mock_provider.side_effect = ValueError("Native provider initialization failed")
mock_provider.side_effect = ValueError(
"Native provider initialization failed"
)
mock_get_native.return_value = mock_provider
with pytest.raises(ImportError) as excinfo:
@@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():
def test_prefixed_models_with_invalid_constants_use_litellm():
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
"""Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
# Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
assert llm.is_litellm is True
assert llm.model == "openai/gemini-2.5-flash"
# Test openai/ prefix with unknown future model → LiteLLM
llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
# Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
assert llm2.is_litellm is True
assert llm2.model == "openai/gpt-future-6"
assert llm2.model == "openai/custom-finetune-model"
# Test anthropic/ prefix with non-Anthropic model → LiteLLM
llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
@@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
assert llm3.model == "anthropic/gpt-4o"
def test_prefixed_models_with_valid_patterns_use_native_sdk():
"""Test that models matching provider patterns use native SDK even if not in constants."""
with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
llm = LLM(model="openai/gpt-future-6", is_litellm=False)
assert llm.is_litellm is False
assert llm.provider == "openai"
assert llm.model == "gpt-future-6"
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
assert llm2.is_litellm is False
assert llm2.provider == "anthropic"
assert llm2.model == "claude-future-5"
def test_prefixed_models_with_non_native_providers_use_litellm():
"""Test that models with non-native provider prefixes always use LiteLLM."""
# Test groq/ prefix (not a native provider) → LiteLLM
@@ -821,19 +843,36 @@ def test_validate_model_in_constants():
"""Test the _validate_model_in_constants method."""
# OpenAI models
assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
assert LLM._validate_model_in_constants("o1-latest", "openai") is True
assert LLM._validate_model_in_constants("unknown-model", "openai") is False
# Anthropic models
assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
assert (
LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
)
assert LLM._validate_model_in_constants("unknown-model", "claude") is False
# Gemini models
assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
assert LLM._validate_model_in_constants("unknown-model", "gemini") is False
# Azure models
assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True
# Bedrock models
assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
assert (
LLM._validate_model_in_constants(
"anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
)
is True
)
assert (
LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
is True
)

View File

@@ -272,6 +272,99 @@ def another_simple_tool():
return "Hi!"
class TestAsyncDecoratorSupport:
"""Tests for async method support in @agent, @task decorators."""
def test_async_agent_memoization(self):
"""Async agent methods should be properly memoized."""
class AsyncAgentCrew:
call_count = 0
@agent
async def async_agent(self):
AsyncAgentCrew.call_count += 1
return Agent(
role="Async Agent", goal="Async Goal", backstory="Async Backstory"
)
crew = AsyncAgentCrew()
first_call = crew.async_agent()
second_call = crew.async_agent()
assert first_call is second_call, "Async agent memoization failed"
assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"
def test_async_task_memoization(self):
"""Async task methods should be properly memoized."""
class AsyncTaskCrew:
call_count = 0
@task
async def async_task(self):
AsyncTaskCrew.call_count += 1
return Task(
description="Async Description", expected_output="Async Output"
)
crew = AsyncTaskCrew()
first_call = crew.async_task()
second_call = crew.async_task()
assert first_call is second_call, "Async task memoization failed"
assert AsyncTaskCrew.call_count == 1, "Async task called more than once"
def test_async_task_name_inference(self):
"""Async task should have name inferred from method name."""
class AsyncTaskNameCrew:
@task
async def my_async_task(self):
return Task(
description="Async Description", expected_output="Async Output"
)
crew = AsyncTaskNameCrew()
task_instance = crew.my_async_task()
assert task_instance.name == "my_async_task", (
"Async task name not inferred correctly"
)
def test_async_agent_returns_agent_not_coroutine(self):
"""Async agent decorator should return Agent, not coroutine."""
class AsyncAgentTypeCrew:
@agent
async def typed_async_agent(self):
return Agent(
role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
)
crew = AsyncAgentTypeCrew()
result = crew.typed_async_agent()
assert isinstance(result, Agent), (
f"Expected Agent, got {type(result).__name__}"
)
def test_async_task_returns_task_not_coroutine(self):
"""Async task decorator should return Task, not coroutine."""
class AsyncTaskTypeCrew:
@task
async def typed_async_task(self):
return Task(
description="Typed Description", expected_output="Typed Output"
)
crew = AsyncTaskTypeCrew()
result = crew.typed_async_task()
assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"
def test_internal_crew_with_mcp():
from crewai_tools.adapters.tool_collection import ToolCollection

View File

@@ -715,243 +715,3 @@ class TestStreamingImports:
assert StreamChunk is not None
assert StreamChunkType is not None
assert ToolCallChunk is not None
class TestLLMStreamChunkEventToolCall:
"""Tests for LLMStreamChunkEvent with tool call information."""
def test_llm_stream_chunk_event_with_tool_call(self) -> None:
"""Test that LLMStreamChunkEvent correctly handles tool call data."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
ToolCall,
FunctionCall,
)
# Create a tool call event
tool_call = ToolCall(
id="call-123",
function=FunctionCall(
name="search",
arguments='{"query": "test"}',
),
type="function",
index=0,
)
event = LLMStreamChunkEvent(
chunk='{"query": "test"}',
tool_call=tool_call,
call_type=LLMCallType.TOOL_CALL,
)
assert event.chunk == '{"query": "test"}'
assert event.tool_call is not None
assert event.tool_call.id == "call-123"
assert event.tool_call.function.name == "search"
assert event.tool_call.function.arguments == '{"query": "test"}'
assert event.call_type == LLMCallType.TOOL_CALL
def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
"""Test that LLMStreamChunkEvent correctly handles tool call as dict."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
# Create a tool call event using dict (as providers emit)
tool_call_dict = {
"id": "call-456",
"function": {
"name": "get_weather",
"arguments": '{"location": "NYC"}',
},
"type": "function",
"index": 1,
}
event = LLMStreamChunkEvent(
chunk='{"location": "NYC"}',
tool_call=tool_call_dict,
call_type=LLMCallType.TOOL_CALL,
)
assert event.chunk == '{"location": "NYC"}'
assert event.tool_call is not None
assert event.tool_call.id == "call-456"
assert event.tool_call.function.name == "get_weather"
assert event.tool_call.function.arguments == '{"location": "NYC"}'
assert event.call_type == LLMCallType.TOOL_CALL
def test_llm_stream_chunk_event_text_only(self) -> None:
"""Test that LLMStreamChunkEvent works for text-only chunks."""
from crewai.events.types.llm_events import (
LLMCallType,
LLMStreamChunkEvent,
)
event = LLMStreamChunkEvent(
chunk="Hello, world!",
tool_call=None,
call_type=LLMCallType.LLM_CALL,
)
assert event.chunk == "Hello, world!"
assert event.tool_call is None
assert event.call_type == LLMCallType.LLM_CALL
class TestBaseLLMEmitStreamChunkEvent:
"""Tests for BaseLLM._emit_stream_chunk_event method."""
def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
"""Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
from unittest.mock import MagicMock, patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit with tool_call - should infer TOOL_CALL type
tool_call_dict = {
"id": "call-789",
"function": {
"name": "test_tool",
"arguments": '{"arg": "value"}',
},
"type": "function",
"index": 0,
}
llm._emit_stream_chunk_event(
chunk='{"arg": "value"}',
tool_call=tool_call_dict,
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
assert captured_events[0].tool_call is not None
def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
"""Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
from unittest.mock import patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit without tool_call - should infer LLM_CALL type
llm._emit_stream_chunk_event(
chunk="Hello, world!",
tool_call=None,
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.LLM_CALL
assert captured_events[0].tool_call is None
def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
"""Test that _emit_stream_chunk_event respects explicitly provided call_type."""
from unittest.mock import patch
from crewai.llms.base_llm import BaseLLM
from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent
# Create a mock BaseLLM instance
with patch.object(BaseLLM, "__abstractmethods__", set()):
llm = BaseLLM(model="test-model") # type: ignore
captured_events: list[LLMStreamChunkEvent] = []
def capture_emit(source: Any, event: Any) -> None:
if isinstance(event, LLMStreamChunkEvent):
captured_events.append(event)
with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
mock_bus.emit = capture_emit
# Emit with explicit call_type - should use provided type
llm._emit_stream_chunk_event(
chunk="test",
tool_call=None,
call_type=LLMCallType.TOOL_CALL, # Explicitly set even though no tool_call
)
assert len(captured_events) == 1
assert captured_events[0].call_type == LLMCallType.TOOL_CALL
class TestStreamingToolCallExtraction:
"""Tests for tool call extraction from streaming events."""
def test_extract_tool_call_info_from_event(self) -> None:
"""Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
from crewai.utilities.streaming import _extract_tool_call_info
from crewai.events.types.llm_events import (
LLMStreamChunkEvent,
ToolCall,
FunctionCall,
)
from crewai.types.streaming import StreamChunkType
# Create event with tool call
tool_call = ToolCall(
id="call-extract-test",
function=FunctionCall(
name="extract_test",
arguments='{"key": "value"}',
),
type="function",
index=2,
)
event = LLMStreamChunkEvent(
chunk='{"key": "value"}',
tool_call=tool_call,
)
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
assert chunk_type == StreamChunkType.TOOL_CALL
assert tool_call_chunk is not None
assert tool_call_chunk.tool_id == "call-extract-test"
assert tool_call_chunk.tool_name == "extract_test"
assert tool_call_chunk.arguments == '{"key": "value"}'
assert tool_call_chunk.index == 2
def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
"""Test that TEXT type is returned when no tool call is present."""
from crewai.utilities.streaming import _extract_tool_call_info
from crewai.events.types.llm_events import LLMStreamChunkEvent
from crewai.types.streaming import StreamChunkType
event = LLMStreamChunkEvent(
chunk="Just text content",
tool_call=None,
)
chunk_type, tool_call_chunk = _extract_tool_call_info(event)
assert chunk_type == StreamChunkType.TEXT
assert tool_call_chunk is None

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.6.0"
__version__ = "1.6.1"