Mirror of https://github.com/crewAIInc/crewAI.git
Synced 2026-05-06 01:32:36 +00:00

Compare commits: 1.14.4...feat/llm-t (7 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | a5321aae92 |  |
|  | 54f5b7db2e |  |
|  | c9a6955cd6 |  |
|  | 086f534d4e |  |
|  | fe93dfe64c |  |
|  | 5837f8edb8 |  |
|  | cdc4b43620 |  |
```diff
@@ -5,6 +5,7 @@ from crewai_tools.tools.daytona_sandbox_tool.daytona_python_tool import (
+    DaytonaPythonTool,
 )


 __all__ = [
     "DaytonaBaseTool",
     "DaytonaExecTool",
```
```diff
@@ -84,7 +84,7 @@ voyageai = [
     "voyageai~=0.3.5",
 ]
 litellm = [
-    "litellm~=1.83.0",
+    "litellm~=1.83.7",
 ]
 bedrock = [
     "boto3~=1.42.79",
```
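For reference, `~=1.83.7` is a PEP 440 compatible-release pin: it allows any 1.83-series release at or above 1.83.7 while excluding 1.84.0. A quick sketch of the semantics using the `packaging` library (illustrative, not part of the diff):

```python
from packaging.specifiers import SpecifierSet

spec = SpecifierSet("~=1.83.7")
assert "1.83.7" in spec      # lower bound included
assert "1.83.9" in spec      # later patch releases admitted
assert "1.84.0" not in spec  # next minor excluded
```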
```diff
@@ -13,6 +13,7 @@ from crewai.crews.crew_output import CrewOutput
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
 from crewai.llm import LLM
+from crewai.llm_result import LLMResult, ToolCallRecord
 from crewai.llms.base_llm import BaseLLM
 from crewai.process import Process
 from crewai.state.checkpoint_config import CheckpointConfig  # noqa: F401
@@ -195,11 +196,13 @@ __all__ = [
     "Flow",
     "Knowledge",
     "LLMGuardrail",
+    "LLMResult",
     "Memory",
     "PlanningConfig",
     "Process",
     "RuntimeState",
     "Task",
     "TaskOutput",
+    "ToolCallRecord",
     "__version__",
 ]
```
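Together, the import and the two `__all__` entries make the new result types part of the package's public surface. A minimal sketch of what that enables (assumes this branch is installed):

```python
from crewai import LLM, LLMResult, ToolCallRecord  # now re-exported at the root
```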
```diff
@@ -32,6 +32,11 @@ from crewai.events.types.tool_usage_events import (
     ToolUsageFinishedEvent,
     ToolUsageStartedEvent,
 )
+from crewai.llm_result import (
+    LLMResult,
+    ToolCallRecord,
+    estimate_cost_usd as _estimate_cost_usd,
+)
 from crewai.llms.base_llm import (
     BaseLLM,
     JsonResponseFormat,
@@ -1699,6 +1704,7 @@ class LLM(BaseLLM):
         from_task: Task | None = None,
         from_agent: BaseAgent | None = None,
         response_model: type[BaseModel] | None = None,
+        max_iterations: int = 10,
     ) -> str | Any:
         """High-level LLM call method.
@@ -1716,16 +1722,250 @@ class LLM(BaseLLM):
             from_task: Optional Task that invoked the LLM
             from_agent: Optional Agent that invoked the LLM
             response_model: Optional Model that contains a pydantic response model.
+            max_iterations: Maximum number of tool-loop iterations (default 10).
+                Only used when both ``tools`` and ``available_functions``
+                are provided.

         Returns:
-            Union[str, Any]: Either a text response from the LLM (str) or
-                the result of a tool function call (Any).
+            Union[str, LLMResult, Any]:
+                - ``str`` when called without tools (backwards compatible).
+                - ``LLMResult`` when called with tools and available_functions.
+                - ``Any`` when a tool call returns a non-string result in legacy mode.

         Raises:
             TypeError: If messages format is invalid
             ValueError: If response format is not supported
             LLMContextLengthExceededError: If input exceeds model's context limit
         """
+        # When tools AND available_functions are both provided, use the tool loop
+        # which returns an LLMResult with structured metadata.
+        if tools and available_functions:
+            return self._call_with_tool_loop(
+                messages=messages,
+                tools=tools,
+                callbacks=callbacks,
+                available_functions=available_functions,
+                from_task=from_task,
+                from_agent=from_agent,
+                response_model=response_model,
+                max_iterations=max_iterations,
+            )
+
+        # Original single-shot path — returns str (backwards compatible).
+        return self._call_single(
+            messages=messages,
+            tools=tools,
+            callbacks=callbacks,
+            available_functions=available_functions,
+            from_task=from_task,
+            from_agent=from_agent,
+            response_model=response_model,
+        )
+
+    def _call_with_tool_loop(
+        self,
+        messages: str | list[LLMMessage],
+        tools: list[dict[str, BaseTool]],
+        callbacks: list[Any] | None,
+        available_functions: dict[str, Any],
+        from_task: Task | None,
+        from_agent: BaseAgent | None,
+        response_model: type[BaseModel] | None,
+        max_iterations: int,
+    ) -> LLMResult:
+        """Run an LLM tool loop, returning a structured LLMResult.
+
+        Keeps calling the model until it stops requesting tool calls or
+        ``max_iterations`` is reached.
+        """
+        from crewai.types.usage_metrics import UsageMetrics
+
+        if isinstance(messages, str):
+            messages = [{"role": "user", "content": messages}]
+        # Work on a mutable copy so we can append assistant/tool messages.
+        conversation: list[dict[str, Any]] = list(messages)  # type: ignore[arg-type]
+
+        result = LLMResult(
+            text="",
+            tool_calls=[],
+            usage=UsageMetrics(),
+            cost_usd=0.0,
+            iterations=0,
+        )
+
+        for iteration in range(max_iterations):
+            # Call the model WITHOUT available_functions so the internal
+            # handler returns tool_calls as-is instead of executing them.
+            raw = self._call_single(
+                messages=conversation,  # type: ignore[arg-type]
+                tools=tools,
+                callbacks=callbacks,
+                available_functions=None,  # Don't let inner layer execute
+                from_task=from_task,
+                from_agent=from_agent,
+                response_model=response_model,
+            )
+
+            result.iterations = iteration + 1
+
+            # Accumulate usage from this iteration
+            self._accumulate_usage(result)
+
+            # If we got a string back, the model is done (no tool calls).
+            if isinstance(raw, str):
+                result.text = raw
+                break
+
+            # If we got tool_calls (list), execute them and feed results back.
+            if isinstance(raw, list):
+                # Append assistant message with tool calls to conversation
+                assistant_msg: dict[str, Any] = {
+                    "role": "assistant",
+                    "content": None,
+                    "tool_calls": [
+                        {
+                            "id": getattr(tc, "id", f"call_{i}"),
+                            "type": "function",
+                            "function": {
+                                "name": getattr(tc.function, "name", "")
+                                if hasattr(tc, "function")
+                                else "",
+                                "arguments": getattr(tc.function, "arguments", "{}")
+                                if hasattr(tc, "function")
+                                else "{}",
+                            },
+                        }
+                        for i, tc in enumerate(raw)
+                    ],
+                }
+                conversation.append(assistant_msg)
+
+                # Execute each tool call
+                for tc in raw:
+                    func_name = sanitize_tool_name(
+                        getattr(tc.function, "name", "")
+                        if hasattr(tc, "function")
+                        else ""
+                    )
+                    func_args_str = (
+                        getattr(tc.function, "arguments", "{}")
+                        if hasattr(tc, "function")
+                        else "{}"
+                    )
+                    tool_call_id = getattr(tc, "id", f"call_{func_name}")
+
+                    try:
+                        func_args = json.loads(func_args_str)
+                    except (json.JSONDecodeError, TypeError):
+                        func_args = {}
+
+                    record = ToolCallRecord(
+                        name=func_name,
+                        input=func_args,
+                    )
+
+                    if func_name in available_functions:
+                        t0 = datetime.now()
+                        started_at = t0
+                        crewai_event_bus.emit(
+                            self,
+                            event=ToolUsageStartedEvent(
+                                tool_name=func_name,
+                                tool_args=func_args,
+                                from_agent=from_agent,
+                                from_task=from_task,
+                            ),
+                        )
+                        try:
+                            fn = available_functions[func_name]
+                            tool_output = fn(**func_args)
+                            t1 = datetime.now()
+                            record.output = (
+                                str(tool_output) if tool_output is not None else ""
+                            )
+                            record.duration_ms = (t1 - t0).total_seconds() * 1000
+                            crewai_event_bus.emit(
+                                self,
+                                event=ToolUsageFinishedEvent(
+                                    output=tool_output,
+                                    tool_name=func_name,
+                                    tool_args=func_args,
+                                    started_at=started_at,
+                                    finished_at=t1,
+                                    from_task=from_task,
+                                    from_agent=from_agent,
+                                ),
+                            )
+                        except Exception as e:
+                            t1 = datetime.now()
+                            record.output = f"Error: {e}"
+                            record.duration_ms = (t1 - t0).total_seconds() * 1000
+                            record.is_error = True
+                            crewai_event_bus.emit(
+                                self,
+                                event=ToolUsageErrorEvent(
+                                    tool_name=func_name,
+                                    tool_args=func_args,
+                                    error=str(e),
+                                    from_task=from_task,
+                                    from_agent=from_agent,
+                                ),
+                            )
+                    else:
+                        record.output = f"Error: unknown function '{func_name}'"
+                        record.is_error = True
+
+                    result.tool_calls.append(record)
+
+                    # Append tool result message for the model
+                    conversation.append(
+                        {
+                            "role": "tool",
+                            "tool_call_id": tool_call_id,
+                            "content": record.output,
+                        }
+                    )
+            else:
+                # Unexpected return type — treat as final text
+                result.text = str(raw)
+                break
+        else:
+            # max_iterations exhausted — use last text or empty
+            if not result.text and result.tool_calls:
+                result.text = (
+                    f"Max iterations ({max_iterations}) reached. "
+                    f"Last tool: {result.tool_calls[-1].name}"
+                )
+
+        # Estimate cost
+        result.cost_usd = _estimate_cost_usd(
+            self.model,
+            result.usage.prompt_tokens,
+            result.usage.completion_tokens,
+        )
+
+        return result
+
+    def _accumulate_usage(self, result: LLMResult) -> None:
+        """Pull token counts from the internal tracker into the LLMResult."""
+        tracker = getattr(self, "_token_usage", None)
+        if tracker and isinstance(tracker, dict):
+            result.usage.prompt_tokens = tracker.get("prompt_tokens", 0)
+            result.usage.completion_tokens = tracker.get("completion_tokens", 0)
+            result.usage.total_tokens = tracker.get("total_tokens", 0)
+            result.usage.successful_requests += 1
+
+    def _call_single(
+        self,
+        messages: str | list[LLMMessage],
+        tools: list[dict[str, BaseTool]] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Task | None = None,
+        from_agent: BaseAgent | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
+        """Single-shot LLM call (original call() logic)."""
         with llm_call_context() as call_id:
             crewai_event_bus.emit(
                 self,
@@ -1819,7 +2059,7 @@ class LLM(BaseLLM):

                 logging.info("Retrying LLM call without the unsupported 'stop'")

-                return self.call(
+                return self._call_single(
                     messages,
                     tools=tools,
                     callbacks=callbacks,
```
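The dispatch above gives `call()` a dual contract: plain string out when no tool loop is requested, structured `LLMResult` out when it is. A usage sketch of both paths (the `get_weather` schema and function are illustrative, not part of the diff; assumes this branch is installed and a model is configured):

```python
from crewai import LLM, LLMResult

llm = LLM(model="gpt-4o-mini")

# Legacy path: no tools -> plain string, exactly as before.
answer = llm.call("Say hello")
assert isinstance(answer, str)

# Tool-loop path: tools + available_functions -> structured LLMResult.
tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",
        "description": "Get weather for a city",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

result = llm.call(
    "What's the weather in SF?",
    tools=tools,
    available_functions={"get_weather": lambda city: f"Sunny in {city}"},
    max_iterations=5,  # new parameter introduced on this branch
)
assert isinstance(result, LLMResult)
print(result.text, result.iterations, result.cost_usd)
for rec in result.tool_calls:  # ToolCallRecord entries, in call order
    print(rec.name, rec.input, rec.duration_ms, rec.is_error)
```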
lib/crewai/src/crewai/llm_result.py (new file, 112 lines)

```python
"""Structured result types for LLM.call() with tool loop support.

When LLM.call() is invoked with tools and available_functions, it returns
an LLMResult instead of a plain string. This preserves backwards compatibility:
calls without tools still return str.
"""

from __future__ import annotations

from typing import Any

from pydantic import BaseModel, Field

from crewai.types.usage_metrics import UsageMetrics


class ToolCallRecord(BaseModel):
    """Record of a single tool call executed during an LLM tool loop.

    Attributes:
        name: The tool function name.
        input: The arguments passed to the tool.
        output: The string result returned by the tool.
        duration_ms: Wall-clock time for the tool execution in milliseconds.
        is_error: Whether the tool call raised an exception.
    """

    name: str
    input: dict[str, Any] = Field(default_factory=dict)
    output: str = ""
    duration_ms: float = 0.0
    is_error: bool = False


class LLMResult(BaseModel):
    """Structured result from LLM.call() when tools are used.

    Attributes:
        text: The final text response from the model.
        tool_calls: Ordered list of every tool call made during the loop.
        usage: Aggregated token usage across all iterations.
        cost_usd: Estimated cost in USD based on model pricing.
        iterations: Number of LLM round-trips in the tool loop.
    """

    text: str = ""
    tool_calls: list[ToolCallRecord] = Field(default_factory=list)
    usage: UsageMetrics = Field(default_factory=UsageMetrics)
    cost_usd: float = 0.0
    iterations: int = 0


# ---------------------------------------------------------------------------
# Simple cost estimation
# ---------------------------------------------------------------------------
# USD per 1M tokens. Covers major models. Inspired by Iris's pricing table.
PRICING: dict[str, dict[str, float]] = {
    # Anthropic
    "claude-opus-4-7": {"in": 5.00, "out": 25.00},
    "claude-sonnet-4-6": {"in": 3.00, "out": 15.00},
    "claude-sonnet-4-5": {"in": 3.00, "out": 15.00},
    "claude-haiku-4-5": {"in": 1.00, "out": 5.00},
    # OpenAI
    "gpt-4o": {"in": 2.50, "out": 10.00},
    "gpt-4o-mini": {"in": 0.15, "out": 0.60},
    "gpt-4.1": {"in": 2.00, "out": 8.00},
    "gpt-4.1-mini": {"in": 0.40, "out": 1.60},
    "gpt-4.1-nano": {"in": 0.10, "out": 0.40},
    "o1": {"in": 15.00, "out": 60.00},
    "o1-mini": {"in": 3.00, "out": 12.00},
    "o3": {"in": 2.00, "out": 8.00},
    "o3-mini": {"in": 1.10, "out": 4.40},
    "gpt-5": {"in": 1.25, "out": 10.00},
    # Google Gemini
    "gemini-2.5-pro": {"in": 1.25, "out": 10.00},
    "gemini-2.5-flash": {"in": 0.30, "out": 2.50},
    "gemini-2.0-flash": {"in": 0.10, "out": 0.40},
}


def _lookup_pricing(model: str) -> dict[str, float] | None:
    """Resolve a model name to its pricing row.

    Handles provider prefixes (``anthropic/claude-sonnet-4-6``) and partial
    matches (``claude-sonnet-4-6-20250514`` → ``claude-sonnet-4-6``).
    """
    if not model:
        return None
    # Exact match
    if model in PRICING:
        return PRICING[model]
    # Strip provider prefix
    if "/" in model:
        suffix = model.rsplit("/", 1)[1]
        if suffix in PRICING:
            return PRICING[suffix]
        model = suffix
    # Prefix / partial match
    for key in PRICING:
        if model.startswith(key) or key.startswith(model):
            return PRICING[key]
    return None


def estimate_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float:
    """Estimate the cost in USD for a given model and token counts."""
    pricing = _lookup_pricing(model)
    if not pricing:
        return 0.0
    return (
        prompt_tokens * pricing["in"] + completion_tokens * pricing["out"]
    ) / 1_000_000
```
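As a quick sanity check of the pricing arithmetic (rates taken from the PRICING table above, in USD per 1M tokens):

```python
from crewai.llm_result import estimate_cost_usd

# gpt-4o-mini: in=$0.15/M, out=$0.60/M
# (12_000 * 0.15 + 3_000 * 0.60) / 1_000_000 = (1_800 + 1_800) / 1_000_000
cost = estimate_cost_usd("gpt-4o-mini", prompt_tokens=12_000, completion_tokens=3_000)
assert round(cost, 6) == 0.0036

# Provider prefixes resolve via _lookup_pricing's fallback chain.
assert estimate_cost_usd("openai/gpt-4o", 1_000_000, 0) == 2.50
```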
```diff
@@ -27,6 +27,7 @@ from crewai.mcp.filters import (
     create_static_tool_filter,
 )

+
 if TYPE_CHECKING:
     from crewai.mcp.client import MCPClient
     from crewai.mcp.tool_resolver import MCPToolResolver
```
lib/crewai/tests/test_llm_tool_loop.py (new file, 411 lines)

```python
"""Tests for LLM.call() tool loop and LLMResult.

All LLM calls are mocked — no real API traffic.
"""

from __future__ import annotations

import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, patch

import pytest

from crewai.llm_result import (
    LLMResult,
    ToolCallRecord,
    _lookup_pricing,
    estimate_cost_usd,
)


def _make_litellm_llm(model: str = "gpt-4o") -> Any:
    """Create an LLM instance that uses the litellm fallback path."""
    from crewai.llm import LLM
    return LLM(model=model, is_litellm=True)


# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------


def _make_tool_call(name: str, arguments: dict, call_id: str = "call_1"):
    """Build a tool-call object using litellm's actual types."""
    try:
        from litellm.types.utils import (
            ChatCompletionMessageToolCall,
            Function,
        )
        return ChatCompletionMessageToolCall(
            id=call_id,
            function=Function(name=name, arguments=json.dumps(arguments)),
            type="function",
        )
    except ImportError:
        func = SimpleNamespace(name=name, arguments=json.dumps(arguments))
        return SimpleNamespace(id=call_id, function=func, type="function")


def _make_model_response(content: str | None = None, tool_calls: list | None = None):
    """Build a minimal mock ModelResponse that passes isinstance checks.

    We need it to be an instance of litellm's ModelResponse/ModelResponseBase
    so the internal isinstance() checks work. We import those types when
    litellm is available.
    """
    try:
        from litellm.types.utils import (
            Choices,
            Message,
            ModelResponse,
            Usage,
        )

        message = Message(content=content, tool_calls=tool_calls or None)
        choice = Choices(message=message, finish_reason="stop", index=0)
        resp = ModelResponse(
            choices=[choice],
            usage=Usage(
                prompt_tokens=100,
                completion_tokens=50,
                total_tokens=150,
            ),
        )
        return resp
    except ImportError:
        # Fallback to SimpleNamespace if litellm not installed
        message = SimpleNamespace(content=content, tool_calls=tool_calls or [])
        choice = SimpleNamespace(message=message, finish_reason="stop")
        usage = SimpleNamespace(
            prompt_tokens=100,
            completion_tokens=50,
            total_tokens=150,
        )
        resp = SimpleNamespace(
            choices=[choice],
            model_extra={"usage": usage},
        )
        return resp


DUMMY_TOOL_SCHEMA = [
    {
        "type": "function",
        "function": {
            "name": "get_weather",
            "description": "Get weather for a city",
            "parameters": {
                "type": "object",
                "properties": {
                    "city": {"type": "string"},
                },
                "required": ["city"],
            },
        },
    }
]


# ---------------------------------------------------------------------------
# Unit tests for LLMResult / ToolCallRecord
# ---------------------------------------------------------------------------


class TestLLMResultModels:
    def test_tool_call_record_defaults(self):
        r = ToolCallRecord(name="foo")
        assert r.input == {}
        assert r.output == ""
        assert r.duration_ms == 0.0
        assert r.is_error is False

    def test_llm_result_defaults(self):
        r = LLMResult()
        assert r.text == ""
        assert r.tool_calls == []
        assert r.cost_usd == 0.0
        assert r.iterations == 0
        assert r.usage.total_tokens == 0

    def test_llm_result_with_data(self):
        r = LLMResult(
            text="hello",
            tool_calls=[ToolCallRecord(name="foo", input={"a": 1}, output="bar")],
            iterations=2,
            cost_usd=0.005,
        )
        assert r.text == "hello"
        assert len(r.tool_calls) == 1
        assert r.tool_calls[0].name == "foo"


# ---------------------------------------------------------------------------
# Cost estimation
# ---------------------------------------------------------------------------


class TestCostEstimation:
    def test_known_model(self):
        cost = estimate_cost_usd("gpt-4o", prompt_tokens=1_000_000, completion_tokens=0)
        assert cost == pytest.approx(2.50)

    def test_known_model_output(self):
        cost = estimate_cost_usd("gpt-4o", prompt_tokens=0, completion_tokens=1_000_000)
        assert cost == pytest.approx(10.00)

    def test_unknown_model_returns_zero(self):
        cost = estimate_cost_usd("some-random-model-xyz", 1000, 1000)
        assert cost == 0.0

    def test_provider_prefix_stripped(self):
        cost = estimate_cost_usd("anthropic/claude-sonnet-4-6", 1_000_000, 0)
        assert cost == pytest.approx(3.00)

    def test_partial_match(self):
        # "claude-sonnet-4-6-20250514" should match "claude-sonnet-4-6"
        cost = estimate_cost_usd("claude-sonnet-4-6-20250514", 1_000_000, 0)
        assert cost == pytest.approx(3.00)

    def test_lookup_none(self):
        assert _lookup_pricing("") is None
        assert _lookup_pricing("nonexistent") is None


# ---------------------------------------------------------------------------
# LLM.call() backwards compatibility (no tools → returns str)
# ---------------------------------------------------------------------------


class TestCallBackwardsCompat:
    """LLM.call() without tools must return str exactly as before."""

    @patch("crewai.llm.litellm")
    def test_call_without_tools_returns_str(self, mock_litellm):
        """Plain call without tools should return a string."""
        mock_litellm.completion.return_value = _make_model_response(content="Hello world")
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        llm = _make_litellm_llm()
        result = llm.call("Say hello")

        assert isinstance(result, str)
        assert result == "Hello world"


# ---------------------------------------------------------------------------
# LLM.call() with tools → returns LLMResult
# ---------------------------------------------------------------------------


class TestCallWithToolLoop:
    """When tools + available_functions are passed, call() returns LLMResult."""

    @patch("crewai.llm.litellm")
    def test_single_tool_call_then_text(self, mock_litellm):
        """Model calls one tool, then responds with text."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        # First call: model wants to call get_weather
        tool_call = _make_tool_call("get_weather", {"city": "SF"})
        resp1 = _make_model_response(content=None, tool_calls=[tool_call])
        # Second call: model responds with text
        resp2 = _make_model_response(content="It's sunny in SF!")
        mock_litellm.completion.side_effect = [resp1, resp2]

        llm = _make_litellm_llm()

        def get_weather(city: str) -> str:
            return f"Sunny, 72°F in {city}"

        result = llm.call(
            messages="What's the weather in SF?",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": get_weather},
        )

        assert isinstance(result, LLMResult)
        assert result.text == "It's sunny in SF!"
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].name == "get_weather"
        assert result.tool_calls[0].input == {"city": "SF"}
        assert "Sunny" in result.tool_calls[0].output
        assert result.tool_calls[0].is_error is False
        assert result.iterations == 2

    @patch("crewai.llm.litellm")
    def test_multiple_tool_calls_in_sequence(self, mock_litellm):
        """Model calls two tools across two iterations."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        tc1 = _make_tool_call("get_weather", {"city": "SF"}, "call_1")
        resp1 = _make_model_response(content=None, tool_calls=[tc1])

        tc2 = _make_tool_call("get_weather", {"city": "NYC"}, "call_2")
        resp2 = _make_model_response(content=None, tool_calls=[tc2])

        resp3 = _make_model_response(content="SF is sunny, NYC is rainy.")
        mock_litellm.completion.side_effect = [resp1, resp2, resp3]

        llm = _make_litellm_llm()

        def get_weather(city: str) -> str:
            return f"Weather for {city}: fine"

        result = llm.call(
            messages="Compare SF and NYC weather",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": get_weather},
        )

        assert isinstance(result, LLMResult)
        assert len(result.tool_calls) == 2
        assert result.tool_calls[0].input["city"] == "SF"
        assert result.tool_calls[1].input["city"] == "NYC"
        assert result.iterations == 3

    @patch("crewai.llm.litellm")
    def test_max_iterations_stops_loop(self, mock_litellm):
        """Loop stops when max_iterations is reached."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        # Model always wants to call a tool — never stops
        def make_tool_resp():
            tc = _make_tool_call("get_weather", {"city": "SF"})
            return _make_model_response(content=None, tool_calls=[tc])

        mock_litellm.completion.side_effect = [make_tool_resp() for _ in range(5)]

        llm = _make_litellm_llm()

        result = llm.call(
            messages="Loop forever",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": lambda city: "sunny"},
            max_iterations=3,
        )

        assert isinstance(result, LLMResult)
        assert result.iterations == 3
        assert len(result.tool_calls) == 3
        # Should have a text noting max iterations
        assert "Max iterations" in result.text

    @patch("crewai.llm.litellm")
    def test_tool_error_handling(self, mock_litellm):
        """Tool that raises an exception is captured in the record."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        tc = _make_tool_call("get_weather", {"city": "SF"})
        resp1 = _make_model_response(content=None, tool_calls=[tc])
        resp2 = _make_model_response(content="Sorry, couldn't get weather.")
        mock_litellm.completion.side_effect = [resp1, resp2]

        llm = _make_litellm_llm()

        def broken_weather(city: str) -> str:
            raise RuntimeError("API down")

        result = llm.call(
            messages="Weather?",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": broken_weather},
        )

        assert isinstance(result, LLMResult)
        assert len(result.tool_calls) == 1
        assert result.tool_calls[0].is_error is True
        assert "API down" in result.tool_calls[0].output
        assert result.text == "Sorry, couldn't get weather."

    @patch("crewai.llm.litellm")
    def test_unknown_function_error(self, mock_litellm):
        """Tool call for a function not in available_functions."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        tc = _make_tool_call("nonexistent_tool", {})
        resp1 = _make_model_response(content=None, tool_calls=[tc])
        resp2 = _make_model_response(content="I couldn't find that tool.")
        mock_litellm.completion.side_effect = [resp1, resp2]

        llm = _make_litellm_llm()

        result = llm.call(
            messages="Do something",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": lambda city: "sunny"},
        )

        assert isinstance(result, LLMResult)
        assert result.tool_calls[0].is_error is True
        assert "unknown function" in result.tool_calls[0].output

    @patch("crewai.llm.litellm")
    def test_cost_estimation_populated(self, mock_litellm):
        """cost_usd is populated from token usage and model pricing."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        resp = _make_model_response(content="Done!")
        mock_litellm.completion.return_value = resp

        llm = _make_litellm_llm()

        result = llm.call(
            messages="Hello",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": lambda city: "sunny"},
        )

        assert isinstance(result, LLMResult)
        # cost_usd should be >= 0 (may be 0 if usage tracking didn't fire,
        # but the field should exist and be a float)
        assert isinstance(result.cost_usd, float)

    @patch("crewai.llm.litellm")
    def test_immediate_text_response_with_tools(self, mock_litellm):
        """Model responds with text on first call (no tool use)."""
        mock_litellm.drop_params = True
        mock_litellm.suppress_debug_info = True
        mock_litellm.success_callback = []
        mock_litellm._async_success_callback = []
        mock_litellm.callbacks = []

        resp = _make_model_response(content="I know the answer already.")
        mock_litellm.completion.return_value = resp

        llm = _make_litellm_llm()

        result = llm.call(
            messages="What's 2+2?",
            tools=DUMMY_TOOL_SCHEMA,
            available_functions={"get_weather": lambda city: "sunny"},
        )

        assert isinstance(result, LLMResult)
        assert result.text == "I know the answer already."
        assert len(result.tool_calls) == 0
        assert result.iterations == 1
```
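Because every litellm call is mocked, the suite runs without network access. A sketch of invoking it programmatically, equivalent to running pytest on the file from the repo root:

```python
import pytest

raise SystemExit(pytest.main(["lib/crewai/tests/test_llm_tool_loop.py", "-q"]))
```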
```diff
@@ -164,7 +164,7 @@ info = "Commits must follow Conventional Commits 1.0.0."
 [tool.uv]
 # Pinned to include the security patch releases (authlib 1.6.11,
 # langchain-text-splitters 1.1.2) uploaded on 2026-04-16.
-exclude-newer = "2026-04-22"
+exclude-newer = "2026-04-26"

 # composio-core pins rich<14 but textual requires rich>=14.
 # onnxruntime 1.24+ dropped Python 3.10 wheels; cap it so qdrant[fastembed] resolves on 3.10.
```
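uv's `exclude-newer` setting makes the resolver ignore any distribution uploaded to the index after the given date; moving it from 2026-04-22 to 2026-04-26 lets resolution see the security patch releases named in the comment above.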