mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-06 22:58:30 +00:00

Compare commits
4 Commits
devin/1764 ... 1.6.1

| Author | SHA1 | Date |
|---|---|---|
| | bc4e6a3127 | |
| | 37526c693b | |
| | c59173a762 | |
| | 4d8eec96e8 | |

@@ -12,7 +12,7 @@ dependencies = [
    "pytube>=15.0.0",
    "requests>=2.32.5",
    "docker>=7.1.0",
    "crewai==1.6.0",
    "crewai==1.6.1",
    "lancedb>=0.5.4",
    "tiktoken>=0.8.0",
    "beautifulsoup4>=4.13.4",

@@ -291,4 +291,4 @@ __all__ = [
    "ZapierActionTools",
]

__version__ = "1.6.0"
__version__ = "1.6.1"

@@ -48,7 +48,7 @@ Repository = "https://github.com/crewAIInc/crewAI"

[project.optional-dependencies]
tools = [
    "crewai-tools==1.6.0",
    "crewai-tools==1.6.1",
]
embeddings = [
    "tiktoken~=0.8.0"

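The `tools` extra is pinned in lockstep with the core package, so after upgrading it is worth confirming that both distributions report the same version. A minimal sketch, assuming both packages are installed in the current environment:

```python
from importlib.metadata import version

# Both pins move together in this release: crewai 1.6.1 expects crewai-tools 1.6.1.
print(version("crewai"), version("crewai-tools"))
```
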
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:

_suppress_pydantic_deprecation_warnings()

__version__ = "1.6.0"
__version__ = "1.6.1"
_telemetry_submitted = False

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]==1.6.0"
    "crewai[tools]==1.6.1"
]

[project.scripts]

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
    "crewai[tools]==1.6.0"
    "crewai[tools]==1.6.1"
]

[project.scripts]

@@ -406,46 +406,100 @@ class LLM(BaseLLM):
        instance.is_litellm = True
        return instance

    @classmethod
    def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
        """Check if a model name matches provider-specific patterns.

        This allows supporting models that aren't in the hardcoded constants list,
        including "latest" versions and new models that follow provider naming conventions.

        Args:
            model: The model name to check
            provider: The provider to check against (canonical name)

        Returns:
            True if the model matches the provider's naming pattern, False otherwise
        """
        model_lower = model.lower()

        if provider == "openai":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gpt-", "o1", "o3", "o4", "whisper-"]
            )

        if provider == "anthropic" or provider == "claude":
            return any(
                model_lower.startswith(prefix) for prefix in ["claude-", "anthropic."]
            )

        if provider == "gemini" or provider == "google":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gemini-", "gemma-", "learnlm-"]
            )

        if provider == "bedrock":
            return "." in model_lower

        if provider == "azure":
            return any(
                model_lower.startswith(prefix)
                for prefix in ["gpt-", "gpt-35-", "o1", "o3", "o4", "azure-"]
            )

        return False

    @classmethod
    def _validate_model_in_constants(cls, model: str, provider: str) -> bool:
        """Validate if a model name exists in the provider's constants.
        """Validate if a model name exists in the provider's constants or matches provider patterns.

        This method first checks the hardcoded constants list for known models.
        If not found, it falls back to pattern matching to support new models,
        "latest" versions, and models that follow provider naming conventions.

        Args:
            model: The model name to validate
            provider: The provider to check against (canonical name)

        Returns:
            True if the model exists in the provider's constants, False otherwise
            True if the model exists in constants or matches provider patterns, False otherwise
        """
        if provider == "openai":
            return model in OPENAI_MODELS
        if provider == "openai" and model in OPENAI_MODELS:
            return True

        if provider == "anthropic" or provider == "claude":
            return model in ANTHROPIC_MODELS
        if (
            provider == "anthropic" or provider == "claude"
        ) and model in ANTHROPIC_MODELS:
            return True

        if provider == "gemini":
            return model in GEMINI_MODELS
        if (provider == "gemini" or provider == "google") and model in GEMINI_MODELS:
            return True

        if provider == "bedrock":
            return model in BEDROCK_MODELS
        if provider == "bedrock" and model in BEDROCK_MODELS:
            return True

        if provider == "azure":
            # azure does not provide a list of available models, determine a better way to handle this
            return True

        return False
        # Fallback to pattern matching for models not in constants
        return cls._matches_provider_pattern(model, provider)

    @classmethod
    def _infer_provider_from_model(cls, model: str) -> str:
        """Infer the provider from the model name.

        This method first checks the hardcoded constants list for known models.
        If not found, it uses pattern matching to infer the provider from model name patterns.
        This allows supporting new models and "latest" versions without hardcoding.

        Args:
            model: The model name without provider prefix

        Returns:
            The inferred provider name, defaults to "openai"
        """

        if model in OPENAI_MODELS:
            return "openai"

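The hunk above first consults the hardcoded model constants and only then falls back to prefix-based pattern matching, so unreleased or "-latest" models still validate. A minimal standalone sketch of that two-step check, using a stand-in constants set and the OpenAI prefixes taken from the diff (illustrative only, not the library's API):

```python
KNOWN_OPENAI_MODELS = {"gpt-4o", "gpt-4o-mini"}  # stand-in for OPENAI_MODELS
OPENAI_PREFIXES = ("gpt-", "o1", "o3", "o4", "whisper-")  # prefixes from the diff


def accepts_openai_model(model: str) -> bool:
    # Step 1: exact match against the hardcoded constants.
    if model in KNOWN_OPENAI_MODELS:
        return True
    # Step 2: fall back to provider naming patterns so new models still pass.
    return model.lower().startswith(OPENAI_PREFIXES)


assert accepts_openai_model("gpt-4o")  # known constant
assert accepts_openai_model("gpt-future-6")  # unknown, but matches the "gpt-" prefix
assert not accepts_openai_model("unknown-model")
```
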
@@ -1699,12 +1753,14 @@ class LLM(BaseLLM):
            max_tokens=self.max_tokens,
            presence_penalty=self.presence_penalty,
            frequency_penalty=self.frequency_penalty,
            logit_bias=copy.deepcopy(self.logit_bias, memo)
            if self.logit_bias
            else None,
            response_format=copy.deepcopy(self.response_format, memo)
            if self.response_format
            else None,
            logit_bias=(
                copy.deepcopy(self.logit_bias, memo) if self.logit_bias else None
            ),
            response_format=(
                copy.deepcopy(self.response_format, memo)
                if self.response_format
                else None
            ),
            seed=self.seed,
            logprobs=self.logprobs,
            top_logprobs=self.top_logprobs,

@@ -316,33 +316,11 @@ class BaseLLM(ABC):
        from_task: Task | None = None,
        from_agent: Agent | None = None,
        tool_call: dict[str, Any] | None = None,
        call_type: LLMCallType | None = None,
    ) -> None:
        """Emit stream chunk event.

        Args:
            chunk: The text content of the chunk
            from_task: Optional task that initiated the call
            from_agent: Optional agent that initiated the call
            tool_call: Optional tool call information as a dict with keys:
                - id: Tool call ID
                - function: Dict with 'name' and 'arguments'
                - type: Tool call type (e.g., 'function')
                - index: Index of the tool call
            call_type: Optional call type. If not provided, it will be inferred
                from the presence of tool_call (TOOL_CALL if tool_call is present,
                LLM_CALL otherwise)
        """
        """Emit stream chunk event."""
        if not hasattr(crewai_event_bus, "emit"):
            raise ValueError("crewai_event_bus does not have an emit method") from None

        # Infer call_type from tool_call presence if not explicitly provided
        effective_call_type = call_type
        if effective_call_type is None:
            effective_call_type = (
                LLMCallType.TOOL_CALL if tool_call is not None else LLMCallType.LLM_CALL
            )

        crewai_event_bus.emit(
            self,
            event=LLMStreamChunkEvent(
@@ -350,7 +328,6 @@ class BaseLLM(ABC):
                tool_call=tool_call,
                from_task=from_task,
                from_agent=from_agent,
                call_type=effective_call_type,
            ),
        )

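The inference rule added above is small but easy to get wrong when re-implementing an event consumer: an explicit call_type always wins, otherwise the presence of a tool_call decides. A hedged sketch with a stand-in enum (LLMCallType in crewai exposes TOOL_CALL and LLM_CALL as used in the diff):

```python
from enum import Enum, auto
from typing import Any


class CallType(Enum):
    LLM_CALL = auto()
    TOOL_CALL = auto()


def infer_call_type(
    tool_call: dict[str, Any] | None, call_type: CallType | None
) -> CallType:
    # An explicit call_type takes precedence; otherwise infer from tool_call presence.
    if call_type is not None:
        return call_type
    return CallType.TOOL_CALL if tool_call is not None else CallType.LLM_CALL


assert infer_call_type({"id": "call-1"}, None) is CallType.TOOL_CALL
assert infer_call_type(None, None) is CallType.LLM_CALL
assert infer_call_type(None, CallType.TOOL_CALL) is CallType.TOOL_CALL
```
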
@@ -182,6 +182,8 @@ OPENAI_MODELS: list[OpenAIModels] = [

AnthropicModels: TypeAlias = Literal[
    "claude-opus-4-5-20251101",
    "claude-opus-4-5",
    "claude-3-7-sonnet-latest",
    "claude-3-7-sonnet-20250219",
    "claude-3-5-haiku-latest",
@@ -208,6 +210,8 @@ AnthropicModels: TypeAlias = Literal[
    "claude-3-haiku-20240307",
]
ANTHROPIC_MODELS: list[AnthropicModels] = [
    "claude-opus-4-5-20251101",
    "claude-opus-4-5",
    "claude-3-7-sonnet-latest",
    "claude-3-7-sonnet-20250219",
    "claude-3-5-haiku-latest",
@@ -452,6 +456,7 @@ BedrockModels: TypeAlias = Literal[
    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
    "anthropic.claude-haiku-4-5-20251001-v1:0",
    "anthropic.claude-instant-v1:2:100k",
    "anthropic.claude-opus-4-5-20251101-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",
@@ -524,6 +529,7 @@ BEDROCK_MODELS: list[BedrockModels] = [
    "anthropic.claude-3-sonnet-20240229-v1:0:28k",
    "anthropic.claude-haiku-4-5-20251001-v1:0",
    "anthropic.claude-instant-v1:2:100k",
    "anthropic.claude-opus-4-5-20251101-v1:0",
    "anthropic.claude-opus-4-1-20250805-v1:0",
    "anthropic.claude-opus-4-20250514-v1:0",
    "anthropic.claude-sonnet-4-20250514-v1:0",

@@ -450,14 +450,9 @@ class AnthropicCompletion(BaseLLM):
        # (the SDK sets it internally)
        stream_params = {k: v for k, v in params.items() if k != "stream"}

        # Track tool use blocks during streaming
        current_tool_use: dict[str, Any] = {}
        tool_use_index = 0

        # Make streaming API call
        with self.client.messages.stream(**stream_params) as stream:
            for event in stream:
                # Handle text content
                if hasattr(event, "delta") and hasattr(event.delta, "text"):
                    text_delta = event.delta.text
                    full_response += text_delta
@@ -467,55 +462,6 @@ class AnthropicCompletion(BaseLLM):
                        from_agent=from_agent,
                    )

                # Handle tool use start (content_block_start event with tool_use type)
                if hasattr(event, "content_block") and hasattr(event.content_block, "type"):
                    if event.content_block.type == "tool_use":
                        current_tool_use = {
                            "id": getattr(event.content_block, "id", None),
                            "name": getattr(event.content_block, "name", ""),
                            "input": "",
                            "index": tool_use_index,
                        }
                        tool_use_index += 1
                        # Emit tool call start event
                        tool_call_event_data = {
                            "id": current_tool_use["id"],
                            "function": {
                                "name": current_tool_use["name"],
                                "arguments": "",
                            },
                            "type": "function",
                            "index": current_tool_use["index"],
                        }
                        self._emit_stream_chunk_event(
                            chunk="",
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

                # Handle tool use input delta (input_json events)
                if hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
                    partial_json = event.delta.partial_json
                    if current_tool_use and partial_json:
                        current_tool_use["input"] += partial_json
                        # Emit tool call delta event
                        tool_call_event_data = {
                            "id": current_tool_use["id"],
                            "function": {
                                "name": current_tool_use["name"],
                                "arguments": partial_json,
                            },
                            "type": "function",
                            "index": current_tool_use["index"],
                        }
                        self._emit_stream_chunk_event(
                            chunk=partial_json,
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

            final_message: Message = stream.get_final_message()

        usage = self._extract_anthropic_token_usage(final_message)

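The streaming handler above emits one event per partial_json fragment while also accumulating the fragments, because the tool arguments only become valid JSON once the tool_use block finishes. A toy illustration of that accumulation (the fragment values are made up):

```python
import json

# Fragments arrive in arbitrary splits; only the full concatenation parses.
current_tool_use = {"id": "toolu_1", "name": "search", "input": ""}
for partial_json in ['{"que', 'ry": "crew', 'ai"}']:
    current_tool_use["input"] += partial_json

assert json.loads(current_tool_use["input"]) == {"query": "crewai"}
```
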
@@ -8,6 +8,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel

from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.converter import generate_model_description
from crewai.utilities.exceptions.context_window_exceeding_exception import (
    LLMContextLengthExceededError,
)
@@ -26,6 +27,7 @@ try:
    from azure.ai.inference.models import (
        ChatCompletions,
        ChatCompletionsToolCall,
        JsonSchemaFormat,
        StreamingChatCompletionsUpdate,
    )
    from azure.core.credentials import (
@@ -278,13 +280,16 @@ class AzureCompletion(BaseLLM):
        }

        if response_model and self.is_openai_model:
            params["response_format"] = {
                "type": "json_schema",
                "json_schema": {
                    "name": response_model.__name__,
                    "schema": response_model.model_json_schema(),
                },
            }
            model_description = generate_model_description(response_model)
            json_schema_info = model_description["json_schema"]
            json_schema_name = json_schema_info["name"]

            params["response_format"] = JsonSchemaFormat(
                name=json_schema_name,
                schema=json_schema_info["schema"],
                description=f"Schema for {json_schema_name}",
                strict=json_schema_info["strict"],
            )

        # Only include model parameter for non-Azure OpenAI endpoints
        # Azure OpenAI endpoints have the deployment name in the URL
@@ -311,8 +316,8 @@ class AzureCompletion(BaseLLM):
            params["tool_choice"] = "auto"

        additional_params = self.additional_params
        additional_drop_params = additional_params.get('additional_drop_params')
        drop_params = additional_params.get('drop_params')
        additional_drop_params = additional_params.get("additional_drop_params")
        drop_params = additional_params.get("drop_params")

        if drop_params and isinstance(additional_drop_params, list):
            for drop_param in additional_drop_params:
@@ -503,10 +508,8 @@ class AzureCompletion(BaseLLM):
                call_id = tool_call.id or "default"
                if call_id not in tool_calls:
                    tool_calls[call_id] = {
                        "id": call_id,
                        "name": "",
                        "arguments": "",
                        "index": getattr(tool_call, "index", 0) or 0,
                    }

                if tool_call.function and tool_call.function.name:
@@ -516,23 +519,6 @@ class AzureCompletion(BaseLLM):
                        tool_call.function.arguments
                    )

                # Emit tool call streaming event
                tool_call_event_data = {
                    "id": tool_calls[call_id]["id"],
                    "function": {
                        "name": tool_calls[call_id]["name"],
                        "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                    },
                    "type": "function",
                    "index": tool_calls[call_id]["index"],
                }
                self._emit_stream_chunk_event(
                    chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                    from_task=from_task,
                    from_agent=from_agent,
                    tool_call=tool_call_event_data,
                )

        # Handle completed tool calls
        if tool_calls and available_functions:
            for call_data in tool_calls.values():

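The Azure path now derives its structured-output schema from generate_model_description instead of hand-building a dict. The exact return shape is internal to crewai, but the hunk above relies on a name/schema/strict payload; a rough sketch of producing an equivalent payload directly from a Pydantic model (illustrative only):

```python
from pydantic import BaseModel


class Weather(BaseModel):
    city: str
    temperature_c: float


# Approximate shape of the json_schema info the hunk above feeds into JsonSchemaFormat.
json_schema_info = {
    "name": Weather.__name__,
    "schema": Weather.model_json_schema(),
    "strict": True,  # assumed value; generate_model_description decides this in crewai
}
print(json_schema_info["name"], sorted(json_schema_info["schema"]["properties"]))
```
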
@@ -567,31 +567,12 @@ class BedrockCompletion(BaseLLM):

            elif "contentBlockStart" in event:
                start = event["contentBlockStart"].get("start", {})
                block_index = event["contentBlockStart"].get("contentBlockIndex", 0)
                if "toolUse" in start:
                    current_tool_use = start["toolUse"]
                    current_tool_use["_block_index"] = block_index
                    current_tool_use["_accumulated_input"] = ""
                    tool_use_id = current_tool_use.get("toolUseId")
                    logging.debug(
                        f"Tool use started in stream: {current_tool_use.get('name')} (ID: {tool_use_id})"
                    )
                    # Emit tool call start event
                    tool_call_event_data = {
                        "id": tool_use_id,
                        "function": {
                            "name": current_tool_use.get("name", ""),
                            "arguments": "",
                        },
                        "type": "function",
                        "index": block_index,
                    }
                    self._emit_stream_chunk_event(
                        chunk="",
                        from_task=from_task,
                        from_agent=from_agent,
                        tool_call=tool_call_event_data,
                    )

            elif "contentBlockDelta" in event:
                delta = event["contentBlockDelta"]["delta"]
@@ -608,23 +589,6 @@ class BedrockCompletion(BaseLLM):
                    tool_input = delta["toolUse"].get("input", "")
                    if tool_input:
                        logging.debug(f"Tool input delta: {tool_input}")
                        current_tool_use["_accumulated_input"] += tool_input
                        # Emit tool call delta event
                        tool_call_event_data = {
                            "id": current_tool_use.get("toolUseId"),
                            "function": {
                                "name": current_tool_use.get("name", ""),
                                "arguments": tool_input,
                            },
                            "type": "function",
                            "index": current_tool_use.get("_block_index", 0),
                        }
                        self._emit_stream_chunk_event(
                            chunk=tool_input,
                            from_task=from_task,
                            from_agent=from_agent,
                            tool_call=tool_call_event_data,
                        )

            # Content block stop - end of a content block
            elif "contentBlockStop" in event:

@@ -1,4 +1,3 @@
import json
import logging
import os
import re
@@ -497,7 +496,7 @@ class GeminiCompletion(BaseLLM):
            if hasattr(chunk, "candidates") and chunk.candidates:
                candidate = chunk.candidates[0]
                if candidate.content and candidate.content.parts:
                    for part_index, part in enumerate(candidate.content.parts):
                    for part in candidate.content.parts:
                        if hasattr(part, "function_call") and part.function_call:
                            call_id = part.function_call.name or "default"
                            if call_id not in function_calls:
@@ -506,27 +505,8 @@ class GeminiCompletion(BaseLLM):
                                    "args": dict(part.function_call.args)
                                    if part.function_call.args
                                    else {},
                                    "index": part_index,
                                }

                            # Emit tool call streaming event
                            args_str = json.dumps(function_calls[call_id]["args"]) if function_calls[call_id]["args"] else ""
                            tool_call_event_data = {
                                "id": call_id,
                                "function": {
                                    "name": function_calls[call_id]["name"],
                                    "arguments": args_str,
                                },
                                "type": "function",
                                "index": function_calls[call_id]["index"],
                            }
                            self._emit_stream_chunk_event(
                                chunk=args_str,
                                from_task=from_task,
                                from_agent=from_agent,
                                tool_call=tool_call_event_data,
                            )

        # Handle completed function calls
        if function_calls and available_functions:
            for call_data in function_calls.values():

@@ -510,10 +510,8 @@ class OpenAICompletion(BaseLLM):
                call_id = tool_call.id or "default"
                if call_id not in tool_calls:
                    tool_calls[call_id] = {
                        "id": call_id,
                        "name": "",
                        "arguments": "",
                        "index": tool_call.index if tool_call.index is not None else 0,
                    }

                if tool_call.function and tool_call.function.name:
@@ -521,23 +519,6 @@ class OpenAICompletion(BaseLLM):
                if tool_call.function and tool_call.function.arguments:
                    tool_calls[call_id]["arguments"] += tool_call.function.arguments

                # Emit tool call streaming event
                tool_call_event_data = {
                    "id": tool_calls[call_id]["id"],
                    "function": {
                        "name": tool_calls[call_id]["name"],
                        "arguments": tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                    },
                    "type": "function",
                    "index": tool_calls[call_id]["index"],
                }
                self._emit_stream_chunk_event(
                    chunk=tool_call.function.arguments if tool_call.function and tool_call.function.arguments else "",
                    from_task=from_task,
                    from_agent=from_agent,
                    tool_call=tool_call_event_data,
                )

        if tool_calls and available_functions:
            for call_data in tool_calls.values():
                function_name = call_data["name"]

@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import wraps
import inspect
from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, TypeVar, overload

from crewai.project.utils import memoize
@@ -156,6 +158,23 @@ def cache_handler(meth: Callable[P, R]) -> CacheHandlerMethod[P, R]:
    return CacheHandlerMethod(memoize(meth))


def _call_method(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    """Call a method, awaiting it if async and running in an event loop."""
    result = method(*args, **kwargs)
    if inspect.iscoroutine(result):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop and loop.is_running():
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as pool:
                return pool.submit(asyncio.run, result).result()
        return asyncio.run(result)
    return result


@overload
def crew(
    meth: Callable[Concatenate[SelfT, P], Crew],
@@ -198,7 +217,7 @@ def crew(

        # Instantiate tasks in order
        for _, task_method in tasks:
            task_instance = task_method(self)
            task_instance = _call_method(task_method, self)
            instantiated_tasks.append(task_instance)
            agent_instance = getattr(task_instance, "agent", None)
            if agent_instance and agent_instance.role not in agent_roles:
@@ -207,7 +226,7 @@ def crew(

        # Instantiate agents not included by tasks
        for _, agent_method in agents:
            agent_instance = agent_method(self)
            agent_instance = _call_method(agent_method, self)
            if agent_instance.role not in agent_roles:
                instantiated_agents.append(agent_instance)
                agent_roles.add(agent_instance.role)
@@ -215,7 +234,7 @@ def crew(
        self.agents = instantiated_agents
        self.tasks = instantiated_tasks

        crew_instance = meth(self, *args, **kwargs)
        crew_instance: Crew = _call_method(meth, self, *args, **kwargs)

        def callback_wrapper(
            hook: Callable[Concatenate[CrewInstance, P2], R2], instance: CrewInstance

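_call_method above lets the @crew wiring treat sync and async @agent/@task methods uniformly: call the method, and if it handed back a coroutine, drive it to completion even when an event loop is already running, by delegating to a worker thread. A self-contained sketch of that pattern (not the crewai implementation itself):

```python
import asyncio
import concurrent.futures
import inspect
from typing import Any, Callable


def call_maybe_async(method: Callable[..., Any], *args: Any, **kwargs: Any) -> Any:
    result = method(*args, **kwargs)
    if not inspect.iscoroutine(result):
        return result
    try:
        loop = asyncio.get_running_loop()
    except RuntimeError:
        loop = None
    if loop and loop.is_running():
        # asyncio.run() can't be called on a thread with a running loop,
        # so hand the coroutine to a worker thread and wait for the result.
        with concurrent.futures.ThreadPoolExecutor() as pool:
            return pool.submit(asyncio.run, result).result()
    return asyncio.run(result)


async def make_task() -> str:
    return "task instance"


assert call_maybe_async(make_task) == "task instance"
```
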
@@ -1,7 +1,8 @@
"""Utility functions for the crewai project module."""

from collections.abc import Callable
from collections.abc import Callable, Coroutine
from functools import wraps
import inspect
from typing import Any, ParamSpec, TypeVar, cast

from pydantic import BaseModel
@@ -37,8 +38,8 @@ def _make_hashable(arg: Any) -> Any:
def memoize(meth: Callable[P, R]) -> Callable[P, R]:
    """Memoize a method by caching its results based on arguments.

    Handles Pydantic BaseModel instances by converting them to JSON strings
    before hashing for cache lookup.
    Handles both sync and async methods. Pydantic BaseModel instances are
    converted to JSON strings before hashing for cache lookup.

    Args:
        meth: The method to memoize.
@@ -46,18 +47,16 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
    Returns:
        A memoized version of the method that caches results.
    """
    if inspect.iscoroutinefunction(meth):
        return cast(Callable[P, R], _memoize_async(meth))
    return _memoize_sync(meth)


def _memoize_sync(meth: Callable[P, R]) -> Callable[P, R]:
    """Memoize a synchronous method."""

    @wraps(meth)
    def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        """Wrapper that converts arguments to hashable form before caching.

        Args:
            *args: Positional arguments to the memoized method.
            **kwargs: Keyword arguments to the memoized method.

        Returns:
            The result of the memoized method call.
        """
        hashable_args = tuple(_make_hashable(arg) for arg in args)
        hashable_kwargs = tuple(
            sorted((k, _make_hashable(v)) for k, v in kwargs.items())
@@ -73,3 +72,27 @@ def memoize(meth: Callable[P, R]) -> Callable[P, R]:
        return result

    return cast(Callable[P, R], wrapper)


def _memoize_async(
    meth: Callable[P, Coroutine[Any, Any, R]],
) -> Callable[P, Coroutine[Any, Any, R]]:
    """Memoize an async method."""

    @wraps(meth)
    async def wrapper(*args: P.args, **kwargs: P.kwargs) -> R:
        hashable_args = tuple(_make_hashable(arg) for arg in args)
        hashable_kwargs = tuple(
            sorted((k, _make_hashable(v)) for k, v in kwargs.items())
        )
        cache_key = str((hashable_args, hashable_kwargs))

        cached_result: R | None = cache.read(tool=meth.__name__, input=cache_key)
        if cached_result is not None:
            return cached_result

        result = await meth(*args, **kwargs)
        cache.add(tool=meth.__name__, input=cache_key, output=result)
        return result

    return wrapper

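The split above dispatches on inspect.iscoroutinefunction so async @agent/@task methods are cached just like sync ones. A simplified sketch of that dispatch, using a plain dict instead of crewai's cache handler:

```python
import asyncio
import inspect
from functools import wraps
from typing import Any, Callable


def memoize(meth: Callable[..., Any]) -> Callable[..., Any]:
    cache: dict[str, Any] = {}

    def make_key(args: tuple, kwargs: dict) -> str:
        # String keys sidestep unhashable arguments, mirroring the approach above.
        return str((args, tuple(sorted(kwargs.items()))))

    if inspect.iscoroutinefunction(meth):
        @wraps(meth)
        async def async_wrapper(*args: Any, **kwargs: Any) -> Any:
            key = make_key(args, kwargs)
            if key not in cache:
                cache[key] = await meth(*args, **kwargs)
            return cache[key]

        return async_wrapper

    @wraps(meth)
    def sync_wrapper(*args: Any, **kwargs: Any) -> Any:
        key = make_key(args, kwargs)
        if key not in cache:
            cache[key] = meth(*args, **kwargs)
        return cache[key]

    return sync_wrapper


@memoize
async def build_agent(name: str) -> dict:
    return {"role": name}


first = asyncio.run(build_agent("Researcher"))
second = asyncio.run(build_agent("Researcher"))
assert first is second  # the second call returns the cached object
```
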
@@ -2,8 +2,10 @@

from __future__ import annotations

import asyncio
from collections.abc import Callable
from functools import partial
import inspect
from pathlib import Path
from typing import (
    TYPE_CHECKING,
@@ -132,6 +134,22 @@ class CrewClass(Protocol):
    crew: Callable[..., Crew]


def _resolve_result(result: Any) -> Any:
    """Resolve a potentially async result to its value."""
    if inspect.iscoroutine(result):
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            loop = None
        if loop and loop.is_running():
            import concurrent.futures

            with concurrent.futures.ThreadPoolExecutor() as pool:
                return pool.submit(asyncio.run, result).result()
        return asyncio.run(result)
    return result


class DecoratedMethod(Generic[P, R]):
    """Base wrapper for methods with decorator metadata.

@@ -162,7 +180,12 @@ class DecoratedMethod(Generic[P, R]):
        """
        if obj is None:
            return self
        bound = partial(self._meth, obj)
        inner = partial(self._meth, obj)

        def _bound(*args: Any, **kwargs: Any) -> R:
            result: R = _resolve_result(inner(*args, **kwargs))  # type: ignore[call-arg]
            return result

        for attr in (
            "is_agent",
            "is_llm",
@@ -174,8 +197,8 @@ class DecoratedMethod(Generic[P, R]):
            "is_crew",
        ):
            if hasattr(self, attr):
                setattr(bound, attr, getattr(self, attr))
        return bound
                setattr(_bound, attr, getattr(self, attr))
        return _bound

    def __call__(self, *args: P.args, **kwargs: P.kwargs) -> R:
        """Call the wrapped method.
@@ -236,6 +259,7 @@ class BoundTaskMethod(Generic[TaskResultT]):
            The task result with name ensured.
        """
        result = self._task_method.unwrap()(self._obj, *args, **kwargs)
        result = _resolve_result(result)
        return self._task_method.ensure_task_name(result)


@@ -292,7 +316,9 @@ class TaskMethod(Generic[P, TaskResultT]):
        Returns:
            The task instance with name set if not already provided.
        """
        return self.ensure_task_name(self._meth(*args, **kwargs))
        result = self._meth(*args, **kwargs)
        result = _resolve_result(result)
        return self.ensure_task_name(result)

    def unwrap(self) -> Callable[P, TaskResultT]:
        """Get the original unwrapped method.

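The __get__ change above swaps a bare functools.partial for a small wrapper function so the bound call can resolve coroutines before returning, while the decorator's marker attributes (is_agent, is_task, ...) are copied onto the wrapper. A stripped-down sketch of that descriptor pattern, with hypothetical class names and no async handling:

```python
from functools import partial
from typing import Any, Callable


class MarkedMethod:
    """Descriptor that binds a method and carries decorator marker flags."""

    def __init__(self, meth: Callable[..., Any], **markers: Any) -> None:
        self._meth = meth
        self._markers = markers

    def __get__(self, obj: Any, objtype: type | None = None) -> Any:
        if obj is None:
            return self
        inner = partial(self._meth, obj)

        def _bound(*args: Any, **kwargs: Any) -> Any:
            # Post-process the result here (crewai resolves coroutines at this point).
            return inner(*args, **kwargs)

        for name, value in self._markers.items():
            setattr(_bound, name, value)  # e.g. is_agent / is_task flags
        return _bound


class MyCrew:
    def _agent_impl(self) -> str:
        return "agent"

    agent_method = MarkedMethod(_agent_impl, is_agent=True)


crew = MyCrew()
assert crew.agent_method() == "agent"
assert crew.agent_method.is_agent is True
```
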
@@ -243,7 +243,11 @@ def test_validate_call_params_not_supported():

    # Patch supports_response_schema to simulate an unsupported model.
    with patch("crewai.llm.supports_response_schema", return_value=False):
        llm = LLM(model="gemini/gemini-1.5-pro", response_format=DummyResponse, is_litellm=True)
        llm = LLM(
            model="gemini/gemini-1.5-pro",
            response_format=DummyResponse,
            is_litellm=True,
        )
        with pytest.raises(ValueError) as excinfo:
            llm._validate_call_params()
        assert "does not support response_format" in str(excinfo.value)
@@ -702,13 +706,16 @@ def test_ollama_does_not_modify_when_last_is_user(ollama_llm):

    assert formatted == original_messages


def test_native_provider_raises_error_when_supported_but_fails():
    """Test that when a native provider is in SUPPORTED_NATIVE_PROVIDERS but fails to instantiate, we raise the error."""
    with patch("crewai.llm.SUPPORTED_NATIVE_PROVIDERS", ["openai"]):
        with patch("crewai.llm.LLM._get_native_provider") as mock_get_native:
            # Mock that provider exists but throws an error when instantiated
            mock_provider = MagicMock()
            mock_provider.side_effect = ValueError("Native provider initialization failed")
            mock_provider.side_effect = ValueError(
                "Native provider initialization failed"
            )
            mock_get_native.return_value = mock_provider

            with pytest.raises(ImportError) as excinfo:
@@ -751,16 +758,16 @@ def test_prefixed_models_with_valid_constants_use_native_sdk():


def test_prefixed_models_with_invalid_constants_use_litellm():
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants."""
    """Test that models with native provider prefixes use LiteLLM when model is NOT in constants and does NOT match patterns."""
    # Test openai/ prefix with non-OpenAI model (not in OPENAI_MODELS) → LiteLLM
    llm = LLM(model="openai/gemini-2.5-flash", is_litellm=False)
    assert llm.is_litellm is True
    assert llm.model == "openai/gemini-2.5-flash"

    # Test openai/ prefix with unknown future model → LiteLLM
    llm2 = LLM(model="openai/gpt-future-6", is_litellm=False)
    # Test openai/ prefix with model that doesn't match patterns (e.g. no gpt- prefix) → LiteLLM
    llm2 = LLM(model="openai/custom-finetune-model", is_litellm=False)
    assert llm2.is_litellm is True
    assert llm2.model == "openai/gpt-future-6"
    assert llm2.model == "openai/custom-finetune-model"

    # Test anthropic/ prefix with non-Anthropic model → LiteLLM
    llm3 = LLM(model="anthropic/gpt-4o", is_litellm=False)
@@ -768,6 +775,21 @@ def test_prefixed_models_with_invalid_constants_use_litellm():
    assert llm3.model == "anthropic/gpt-4o"


def test_prefixed_models_with_valid_patterns_use_native_sdk():
    """Test that models matching provider patterns use native SDK even if not in constants."""
    with patch.dict(os.environ, {"OPENAI_API_KEY": "test-key"}):
        llm = LLM(model="openai/gpt-future-6", is_litellm=False)
        assert llm.is_litellm is False
        assert llm.provider == "openai"
        assert llm.model == "gpt-future-6"

    with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-key"}):
        llm2 = LLM(model="anthropic/claude-future-5", is_litellm=False)
        assert llm2.is_litellm is False
        assert llm2.provider == "anthropic"
        assert llm2.model == "claude-future-5"


def test_prefixed_models_with_non_native_providers_use_litellm():
    """Test that models with non-native provider prefixes always use LiteLLM."""
    # Test groq/ prefix (not a native provider) → LiteLLM

@@ -821,19 +843,36 @@ def test_validate_model_in_constants():
    """Test the _validate_model_in_constants method."""
    # OpenAI models
    assert LLM._validate_model_in_constants("gpt-4o", "openai") is True
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is False
    assert LLM._validate_model_in_constants("gpt-future-6", "openai") is True
    assert LLM._validate_model_in_constants("o1-latest", "openai") is True
    assert LLM._validate_model_in_constants("unknown-model", "openai") is False

    # Anthropic models
    assert LLM._validate_model_in_constants("claude-opus-4-0", "claude") is True
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is False
    assert LLM._validate_model_in_constants("claude-future-5", "claude") is True
    assert (
        LLM._validate_model_in_constants("claude-3-5-sonnet-latest", "claude") is True
    )
    assert LLM._validate_model_in_constants("unknown-model", "claude") is False

    # Gemini models
    assert LLM._validate_model_in_constants("gemini-2.5-pro", "gemini") is True
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is False
    assert LLM._validate_model_in_constants("gemini-future", "gemini") is True
    assert LLM._validate_model_in_constants("gemma-3-latest", "gemini") is True
    assert LLM._validate_model_in_constants("unknown-model", "gemini") is False

    # Azure models
    assert LLM._validate_model_in_constants("gpt-4o", "azure") is True
    assert LLM._validate_model_in_constants("gpt-35-turbo", "azure") is True

    # Bedrock models
    assert LLM._validate_model_in_constants("anthropic.claude-opus-4-1-20250805-v1:0", "bedrock") is True
    assert (
        LLM._validate_model_in_constants(
            "anthropic.claude-opus-4-1-20250805-v1:0", "bedrock"
        )
        is True
    )
    assert (
        LLM._validate_model_in_constants("anthropic.claude-future-v1:0", "bedrock")
        is True
    )

@@ -272,6 +272,99 @@ def another_simple_tool():
    return "Hi!"


class TestAsyncDecoratorSupport:
    """Tests for async method support in @agent, @task decorators."""

    def test_async_agent_memoization(self):
        """Async agent methods should be properly memoized."""

        class AsyncAgentCrew:
            call_count = 0

            @agent
            async def async_agent(self):
                AsyncAgentCrew.call_count += 1
                return Agent(
                    role="Async Agent", goal="Async Goal", backstory="Async Backstory"
                )

        crew = AsyncAgentCrew()
        first_call = crew.async_agent()
        second_call = crew.async_agent()

        assert first_call is second_call, "Async agent memoization failed"
        assert AsyncAgentCrew.call_count == 1, "Async agent called more than once"

    def test_async_task_memoization(self):
        """Async task methods should be properly memoized."""

        class AsyncTaskCrew:
            call_count = 0

            @task
            async def async_task(self):
                AsyncTaskCrew.call_count += 1
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskCrew()
        first_call = crew.async_task()
        second_call = crew.async_task()

        assert first_call is second_call, "Async task memoization failed"
        assert AsyncTaskCrew.call_count == 1, "Async task called more than once"

    def test_async_task_name_inference(self):
        """Async task should have name inferred from method name."""

        class AsyncTaskNameCrew:
            @task
            async def my_async_task(self):
                return Task(
                    description="Async Description", expected_output="Async Output"
                )

        crew = AsyncTaskNameCrew()
        task_instance = crew.my_async_task()

        assert task_instance.name == "my_async_task", (
            "Async task name not inferred correctly"
        )

    def test_async_agent_returns_agent_not_coroutine(self):
        """Async agent decorator should return Agent, not coroutine."""

        class AsyncAgentTypeCrew:
            @agent
            async def typed_async_agent(self):
                return Agent(
                    role="Typed Agent", goal="Typed Goal", backstory="Typed Backstory"
                )

        crew = AsyncAgentTypeCrew()
        result = crew.typed_async_agent()

        assert isinstance(result, Agent), (
            f"Expected Agent, got {type(result).__name__}"
        )

    def test_async_task_returns_task_not_coroutine(self):
        """Async task decorator should return Task, not coroutine."""

        class AsyncTaskTypeCrew:
            @task
            async def typed_async_task(self):
                return Task(
                    description="Typed Description", expected_output="Typed Output"
                )

        crew = AsyncTaskTypeCrew()
        result = crew.typed_async_task()

        assert isinstance(result, Task), f"Expected Task, got {type(result).__name__}"


def test_internal_crew_with_mcp():
    from crewai_tools.adapters.tool_collection import ToolCollection

@@ -715,243 +715,3 @@ class TestStreamingImports:
        assert StreamChunk is not None
        assert StreamChunkType is not None
        assert ToolCallChunk is not None


class TestLLMStreamChunkEventToolCall:
    """Tests for LLMStreamChunkEvent with tool call information."""

    def test_llm_stream_chunk_event_with_tool_call(self) -> None:
        """Test that LLMStreamChunkEvent correctly handles tool call data."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
            ToolCall,
            FunctionCall,
        )

        # Create a tool call event
        tool_call = ToolCall(
            id="call-123",
            function=FunctionCall(
                name="search",
                arguments='{"query": "test"}',
            ),
            type="function",
            index=0,
        )

        event = LLMStreamChunkEvent(
            chunk='{"query": "test"}',
            tool_call=tool_call,
            call_type=LLMCallType.TOOL_CALL,
        )

        assert event.chunk == '{"query": "test"}'
        assert event.tool_call is not None
        assert event.tool_call.id == "call-123"
        assert event.tool_call.function.name == "search"
        assert event.tool_call.function.arguments == '{"query": "test"}'
        assert event.call_type == LLMCallType.TOOL_CALL

    def test_llm_stream_chunk_event_with_dict_tool_call(self) -> None:
        """Test that LLMStreamChunkEvent correctly handles tool call as dict."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
        )

        # Create a tool call event using dict (as providers emit)
        tool_call_dict = {
            "id": "call-456",
            "function": {
                "name": "get_weather",
                "arguments": '{"location": "NYC"}',
            },
            "type": "function",
            "index": 1,
        }

        event = LLMStreamChunkEvent(
            chunk='{"location": "NYC"}',
            tool_call=tool_call_dict,
            call_type=LLMCallType.TOOL_CALL,
        )

        assert event.chunk == '{"location": "NYC"}'
        assert event.tool_call is not None
        assert event.tool_call.id == "call-456"
        assert event.tool_call.function.name == "get_weather"
        assert event.tool_call.function.arguments == '{"location": "NYC"}'
        assert event.call_type == LLMCallType.TOOL_CALL

    def test_llm_stream_chunk_event_text_only(self) -> None:
        """Test that LLMStreamChunkEvent works for text-only chunks."""
        from crewai.events.types.llm_events import (
            LLMCallType,
            LLMStreamChunkEvent,
        )

        event = LLMStreamChunkEvent(
            chunk="Hello, world!",
            tool_call=None,
            call_type=LLMCallType.LLM_CALL,
        )

        assert event.chunk == "Hello, world!"
        assert event.tool_call is None
        assert event.call_type == LLMCallType.LLM_CALL


class TestBaseLLMEmitStreamChunkEvent:
    """Tests for BaseLLM._emit_stream_chunk_event method."""

    def test_emit_stream_chunk_event_infers_tool_call_type(self) -> None:
        """Test that _emit_stream_chunk_event infers TOOL_CALL type when tool_call is present."""
        from unittest.mock import MagicMock, patch
        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit with tool_call - should infer TOOL_CALL type
            tool_call_dict = {
                "id": "call-789",
                "function": {
                    "name": "test_tool",
                    "arguments": '{"arg": "value"}',
                },
                "type": "function",
                "index": 0,
            }
            llm._emit_stream_chunk_event(
                chunk='{"arg": "value"}',
                tool_call=tool_call_dict,
            )

            assert len(captured_events) == 1
            assert captured_events[0].call_type == LLMCallType.TOOL_CALL
            assert captured_events[0].tool_call is not None

    def test_emit_stream_chunk_event_infers_llm_call_type(self) -> None:
        """Test that _emit_stream_chunk_event infers LLM_CALL type when tool_call is None."""
        from unittest.mock import patch
        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit without tool_call - should infer LLM_CALL type
            llm._emit_stream_chunk_event(
                chunk="Hello, world!",
                tool_call=None,
            )

            assert len(captured_events) == 1
            assert captured_events[0].call_type == LLMCallType.LLM_CALL
            assert captured_events[0].tool_call is None

    def test_emit_stream_chunk_event_respects_explicit_call_type(self) -> None:
        """Test that _emit_stream_chunk_event respects explicitly provided call_type."""
        from unittest.mock import patch
        from crewai.llms.base_llm import BaseLLM
        from crewai.events.types.llm_events import LLMCallType, LLMStreamChunkEvent

        # Create a mock BaseLLM instance
        with patch.object(BaseLLM, "__abstractmethods__", set()):
            llm = BaseLLM(model="test-model")  # type: ignore

        captured_events: list[LLMStreamChunkEvent] = []

        def capture_emit(source: Any, event: Any) -> None:
            if isinstance(event, LLMStreamChunkEvent):
                captured_events.append(event)

        with patch("crewai.llms.base_llm.crewai_event_bus") as mock_bus:
            mock_bus.emit = capture_emit

            # Emit with explicit call_type - should use provided type
            llm._emit_stream_chunk_event(
                chunk="test",
                tool_call=None,
                call_type=LLMCallType.TOOL_CALL,  # Explicitly set even though no tool_call
            )

            assert len(captured_events) == 1
            assert captured_events[0].call_type == LLMCallType.TOOL_CALL


class TestStreamingToolCallExtraction:
    """Tests for tool call extraction from streaming events."""

    def test_extract_tool_call_info_from_event(self) -> None:
        """Test that tool call info is correctly extracted from LLMStreamChunkEvent."""
        from crewai.utilities.streaming import _extract_tool_call_info
        from crewai.events.types.llm_events import (
            LLMStreamChunkEvent,
            ToolCall,
            FunctionCall,
        )
        from crewai.types.streaming import StreamChunkType

        # Create event with tool call
        tool_call = ToolCall(
            id="call-extract-test",
            function=FunctionCall(
                name="extract_test",
                arguments='{"key": "value"}',
            ),
            type="function",
            index=2,
        )

        event = LLMStreamChunkEvent(
            chunk='{"key": "value"}',
            tool_call=tool_call,
        )

        chunk_type, tool_call_chunk = _extract_tool_call_info(event)

        assert chunk_type == StreamChunkType.TOOL_CALL
        assert tool_call_chunk is not None
        assert tool_call_chunk.tool_id == "call-extract-test"
        assert tool_call_chunk.tool_name == "extract_test"
        assert tool_call_chunk.arguments == '{"key": "value"}'
        assert tool_call_chunk.index == 2

    def test_extract_tool_call_info_returns_text_for_no_tool_call(self) -> None:
        """Test that TEXT type is returned when no tool call is present."""
        from crewai.utilities.streaming import _extract_tool_call_info
        from crewai.events.types.llm_events import LLMStreamChunkEvent
        from crewai.types.streaming import StreamChunkType

        event = LLMStreamChunkEvent(
            chunk="Just text content",
            tool_call=None,
        )

        chunk_type, tool_call_chunk = _extract_tool_call_info(event)

        assert chunk_type == StreamChunkType.TEXT
        assert tool_call_chunk is None

@@ -1,3 +1,3 @@
"""CrewAI development tools."""

__version__ = "1.6.0"
__version__ = "1.6.1"