Compare commits

..

1 Commits

Author SHA1 Message Date
Joao Moura
aee571b775 feat: add support for agents to invoke Flows as tools
Agents can now declare flows=[MyFlow] and invoke them as regular tools.
Each Flow class is wrapped as a FlowTool(BaseTool) — the agent decides
WHEN to use it (via tool selection), the Flow handles HOW (deterministic
execution).

- Add flows field to Agent
- Add FlowTool and create_flow_tools in tools/flow_tool.py
- Export from crewai.__init__
- 14 tests

No dependency changes. No memory changes. Just the flow-tool feature.
2026-04-27 05:10:38 -07:00
11 changed files with 294 additions and 773 deletions

View File

@@ -5,7 +5,6 @@ from crewai_tools.tools.daytona_sandbox_tool.daytona_python_tool import (
DaytonaPythonTool,
)
__all__ = [
"DaytonaBaseTool",
"DaytonaExecTool",

View File

@@ -84,7 +84,7 @@ voyageai = [
"voyageai~=0.3.5",
]
litellm = [
"litellm~=1.83.7",
"litellm~=1.83.0",
]
bedrock = [
"boto3~=1.42.79",

View File

@@ -13,7 +13,6 @@ from crewai.crews.crew_output import CrewOutput
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM
from crewai.llm_result import LLMResult, ToolCallRecord
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.state.checkpoint_config import CheckpointConfig # noqa: F401
@@ -95,10 +94,16 @@ try:
}
from crewai.tools.base_tool import BaseTool as _BaseTool
from crewai.tools.flow_tool import (
FlowTool as _FlowTool,
create_flow_tools as _create_flow_tools,
)
from crewai.tools.structured_tool import CrewStructuredTool as _CrewStructuredTool
_base_namespace["BaseTool"] = _BaseTool
_base_namespace["CrewStructuredTool"] = _CrewStructuredTool
_base_namespace["FlowTool"] = _FlowTool
_base_namespace["create_flow_tools"] = _create_flow_tools # type: ignore[assignment]
try:
from crewai.a2a.config import (
@@ -196,13 +201,11 @@ __all__ = [
"Flow",
"Knowledge",
"LLMGuardrail",
"LLMResult",
"Memory",
"PlanningConfig",
"Process",
"RuntimeState",
"Task",
"TaskOutput",
"ToolCallRecord",
"__version__",
]

View File

@@ -85,6 +85,7 @@ from crewai.skills.loader import activate_skill, discover_skills
from crewai.skills.models import INSTRUCTIONS, Skill as SkillModel
from crewai.state.checkpoint_config import CheckpointConfig, apply_checkpoint
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.flow_tool import create_flow_tools
from crewai.types.callback import SerializableCallable
from crewai.utilities.agent_utils import (
get_tool_names,
@@ -305,6 +306,10 @@ class Agent(BaseAgent):
Can be a single A2AConfig/A2AClientConfig/A2AServerConfig, or a list of any number of A2AConfig/A2AClientConfig with a single A2AServerConfig.
""",
)
flows: list[Any] | None = Field(
default=None,
description="Flow classes that the agent can invoke as tools. Each entry is a Flow subclass (not an instance).",
)
agent_executor: CrewAgentExecutor | AgentExecutor | None = Field(
default=None, description="An instance of the CrewAgentExecutor class."
)
@@ -347,6 +352,7 @@ class Agent(BaseAgent):
)
self.set_skills()
self._set_flow_tools()
if self.reasoning and self.planning_config is None:
warnings.warn(
@@ -459,6 +465,16 @@ class Agent(BaseAgent):
self.skills = resolved if resolved else None
def _set_flow_tools(self) -> None:
"""Convert Flow classes in ``self.flows`` to tools and merge them."""
if not self.flows:
return
flow_tools = create_flow_tools(self.flows)
if flow_tools:
if self.tools is None:
self.tools = []
self.tools.extend(flow_tools)
def _is_any_available_memory(self) -> bool:
"""Check if unified memory is available (agent or crew)."""
if getattr(self, "memory", None):

View File

@@ -32,11 +32,6 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llm_result import (
LLMResult,
ToolCallRecord,
estimate_cost_usd as _estimate_cost_usd,
)
from crewai.llms.base_llm import (
BaseLLM,
JsonResponseFormat,
@@ -1704,7 +1699,6 @@ class LLM(BaseLLM):
from_task: Task | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
max_iterations: int = 10,
) -> str | Any:
"""High-level LLM call method.
@@ -1722,250 +1716,16 @@ class LLM(BaseLLM):
from_task: Optional Task that invoked the LLM
from_agent: Optional Agent that invoked the LLM
response_model: Optional Model that contains a pydantic response model.
max_iterations: Maximum number of tool-loop iterations (default 10).
Only used when both ``tools`` and ``available_functions``
are provided.
Returns:
Union[str, LLMResult, Any]:
- ``str`` when called without tools (backwards compatible).
- ``LLMResult`` when called with tools and available_functions.
- ``Any`` when a tool call returns a non-string result in legacy mode.
Union[str, Any]: Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
TypeError: If messages format is invalid
ValueError: If response format is not supported
LLMContextLengthExceededError: If input exceeds model's context limit
"""
# When tools AND available_functions are both provided, use the tool loop
# which returns an LLMResult with structured metadata.
if tools and available_functions:
return self._call_with_tool_loop(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
max_iterations=max_iterations,
)
# Original single-shot path — returns str (backwards compatible).
return self._call_single(
messages=messages,
tools=tools,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
def _call_with_tool_loop(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]],
callbacks: list[Any] | None,
available_functions: dict[str, Any],
from_task: Task | None,
from_agent: BaseAgent | None,
response_model: type[BaseModel] | None,
max_iterations: int,
) -> LLMResult:
"""Run an LLM tool loop, returning a structured LLMResult.
Keeps calling the model until it stops requesting tool calls or
``max_iterations`` is reached.
"""
from crewai.types.usage_metrics import UsageMetrics
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Work on a mutable copy so we can append assistant/tool messages.
conversation: list[dict[str, Any]] = list(messages) # type: ignore[arg-type]
result = LLMResult(
text="",
tool_calls=[],
usage=UsageMetrics(),
cost_usd=0.0,
iterations=0,
)
for iteration in range(max_iterations):
# Call the model WITHOUT available_functions so the internal
# handler returns tool_calls as-is instead of executing them.
raw = self._call_single(
messages=conversation, # type: ignore[arg-type]
tools=tools,
callbacks=callbacks,
available_functions=None, # Don't let inner layer execute
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
result.iterations = iteration + 1
# Accumulate usage from this iteration
self._accumulate_usage(result)
# If we got a string back, the model is done (no tool calls).
if isinstance(raw, str):
result.text = raw
break
# If we got tool_calls (list), execute them and feed results back.
if isinstance(raw, list):
# Append assistant message with tool calls to conversation
assistant_msg: dict[str, Any] = {
"role": "assistant",
"content": None,
"tool_calls": [
{
"id": getattr(tc, "id", f"call_{i}"),
"type": "function",
"function": {
"name": getattr(tc.function, "name", "")
if hasattr(tc, "function")
else "",
"arguments": getattr(tc.function, "arguments", "{}")
if hasattr(tc, "function")
else "{}",
},
}
for i, tc in enumerate(raw)
],
}
conversation.append(assistant_msg)
# Execute each tool call
for tc in raw:
func_name = sanitize_tool_name(
getattr(tc.function, "name", "")
if hasattr(tc, "function")
else ""
)
func_args_str = (
getattr(tc.function, "arguments", "{}")
if hasattr(tc, "function")
else "{}"
)
tool_call_id = getattr(tc, "id", f"call_{func_name}")
try:
func_args = json.loads(func_args_str)
except (json.JSONDecodeError, TypeError):
func_args = {}
record = ToolCallRecord(
name=func_name,
input=func_args,
)
if func_name in available_functions:
t0 = datetime.now()
started_at = t0
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=func_name,
tool_args=func_args,
from_agent=from_agent,
from_task=from_task,
),
)
try:
fn = available_functions[func_name]
tool_output = fn(**func_args)
t1 = datetime.now()
record.output = (
str(tool_output) if tool_output is not None else ""
)
record.duration_ms = (t1 - t0).total_seconds() * 1000
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=tool_output,
tool_name=func_name,
tool_args=func_args,
started_at=started_at,
finished_at=t1,
from_task=from_task,
from_agent=from_agent,
),
)
except Exception as e:
t1 = datetime.now()
record.output = f"Error: {e}"
record.duration_ms = (t1 - t0).total_seconds() * 1000
record.is_error = True
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=func_name,
tool_args=func_args,
error=str(e),
from_task=from_task,
from_agent=from_agent,
),
)
else:
record.output = f"Error: unknown function '{func_name}'"
record.is_error = True
result.tool_calls.append(record)
# Append tool result message for the model
conversation.append(
{
"role": "tool",
"tool_call_id": tool_call_id,
"content": record.output,
}
)
else:
# Unexpected return type — treat as final text
result.text = str(raw)
break
else:
# max_iterations exhausted — use last text or empty
if not result.text and result.tool_calls:
result.text = (
f"Max iterations ({max_iterations}) reached. "
f"Last tool: {result.tool_calls[-1].name}"
)
# Estimate cost
result.cost_usd = _estimate_cost_usd(
self.model,
result.usage.prompt_tokens,
result.usage.completion_tokens,
)
return result
def _accumulate_usage(self, result: LLMResult) -> None:
"""Pull token counts from the internal tracker into the LLMResult."""
tracker = getattr(self, "_token_usage", None)
if tracker and isinstance(tracker, dict):
result.usage.prompt_tokens = tracker.get("prompt_tokens", 0)
result.usage.completion_tokens = tracker.get("completion_tokens", 0)
result.usage.total_tokens = tracker.get("total_tokens", 0)
result.usage.successful_requests += 1
def _call_single(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Task | None = None,
from_agent: BaseAgent | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Single-shot LLM call (original call() logic)."""
with llm_call_context() as call_id:
crewai_event_bus.emit(
self,
@@ -2059,7 +1819,7 @@ class LLM(BaseLLM):
logging.info("Retrying LLM call without the unsupported 'stop'")
return self._call_single(
return self.call(
messages,
tools=tools,
callbacks=callbacks,

View File

@@ -1,112 +0,0 @@
"""Structured result types for LLM.call() with tool loop support.
When LLM.call() is invoked with tools and available_functions, it returns
an LLMResult instead of a plain string. This preserves backwards compatibility:
calls without tools still return str.
"""
from __future__ import annotations
from typing import Any
from pydantic import BaseModel, Field
from crewai.types.usage_metrics import UsageMetrics
class ToolCallRecord(BaseModel):
"""Record of a single tool call executed during an LLM tool loop.
Attributes:
name: The tool function name.
input: The arguments passed to the tool.
output: The string result returned by the tool.
duration_ms: Wall-clock time for the tool execution in milliseconds.
is_error: Whether the tool call raised an exception.
"""
name: str
input: dict[str, Any] = Field(default_factory=dict)
output: str = ""
duration_ms: float = 0.0
is_error: bool = False
class LLMResult(BaseModel):
"""Structured result from LLM.call() when tools are used.
Attributes:
text: The final text response from the model.
tool_calls: Ordered list of every tool call made during the loop.
usage: Aggregated token usage across all iterations.
cost_usd: Estimated cost in USD based on model pricing.
iterations: Number of LLM round-trips in the tool loop.
"""
text: str = ""
tool_calls: list[ToolCallRecord] = Field(default_factory=list)
usage: UsageMetrics = Field(default_factory=UsageMetrics)
cost_usd: float = 0.0
iterations: int = 0
# ---------------------------------------------------------------------------
# Simple cost estimation
# ---------------------------------------------------------------------------
# USD per 1M tokens. Covers major models. Inspired by Iris's pricing table.
PRICING: dict[str, dict[str, float]] = {
# Anthropic
"claude-opus-4-7": {"in": 5.00, "out": 25.00},
"claude-sonnet-4-6": {"in": 3.00, "out": 15.00},
"claude-sonnet-4-5": {"in": 3.00, "out": 15.00},
"claude-haiku-4-5": {"in": 1.00, "out": 5.00},
# OpenAI
"gpt-4o": {"in": 2.50, "out": 10.00},
"gpt-4o-mini": {"in": 0.15, "out": 0.60},
"gpt-4.1": {"in": 2.00, "out": 8.00},
"gpt-4.1-mini": {"in": 0.40, "out": 1.60},
"gpt-4.1-nano": {"in": 0.10, "out": 0.40},
"o1": {"in": 15.00, "out": 60.00},
"o1-mini": {"in": 3.00, "out": 12.00},
"o3": {"in": 2.00, "out": 8.00},
"o3-mini": {"in": 1.10, "out": 4.40},
"gpt-5": {"in": 1.25, "out": 10.00},
# Google Gemini
"gemini-2.5-pro": {"in": 1.25, "out": 10.00},
"gemini-2.5-flash": {"in": 0.30, "out": 2.50},
"gemini-2.0-flash": {"in": 0.10, "out": 0.40},
}
def _lookup_pricing(model: str) -> dict[str, float] | None:
"""Resolve a model name to its pricing row.
Handles provider prefixes (``anthropic/claude-sonnet-4-6``) and partial
matches (``claude-sonnet-4-6-20250514`` → ``claude-sonnet-4-6``).
"""
if not model:
return None
# Exact match
if model in PRICING:
return PRICING[model]
# Strip provider prefix
if "/" in model:
suffix = model.rsplit("/", 1)[1]
if suffix in PRICING:
return PRICING[suffix]
model = suffix
# Prefix / partial match
for key in PRICING:
if model.startswith(key) or key.startswith(model):
return PRICING[key]
return None
def estimate_cost_usd(model: str, prompt_tokens: int, completion_tokens: int) -> float:
"""Estimate the cost in USD for a given model and token counts."""
pricing = _lookup_pricing(model)
if not pricing:
return 0.0
return (
prompt_tokens * pricing["in"] + completion_tokens * pricing["out"]
) / 1_000_000

View File

@@ -27,7 +27,6 @@ from crewai.mcp.filters import (
create_static_tool_filter,
)
if TYPE_CHECKING:
from crewai.mcp.client import MCPClient
from crewai.mcp.tool_resolver import MCPToolResolver

View File

@@ -0,0 +1,82 @@
"""Wrap Flow classes as callable tools so agents can invoke them."""
from __future__ import annotations
import json
from typing import Any
from pydantic import BaseModel, Field
from crewai.tools.base_tool import BaseTool
from crewai.utilities.string_utils import sanitize_tool_name
class FlowToolInputSchema(BaseModel):
"""Default input schema for a FlowTool."""
inputs: str = Field(
default="{}",
description=(
"JSON string of key-value pairs to pass as inputs to the flow. "
"Use '{}' if the flow requires no inputs."
),
)
class FlowTool(BaseTool):
"""Wraps a Flow class as a BaseTool so an agent can invoke it.
The tool instantiates the Flow, calls ``kickoff(inputs=...)`` and returns
the result as a string.
"""
name: str = ""
description: str = ""
flow_class: Any = Field(
default=None,
description="The Flow class (not instance) to wrap.",
exclude=True,
)
args_schema: Any = FlowToolInputSchema
def _run(self, inputs: str = "{}") -> str:
"""Instantiate the Flow, run kickoff, and return the result."""
try:
parsed_inputs = json.loads(inputs) if isinstance(inputs, str) else inputs
except (json.JSONDecodeError, TypeError):
parsed_inputs = {}
if not isinstance(parsed_inputs, dict):
parsed_inputs = {}
flow_instance = self.flow_class()
result = flow_instance.kickoff(inputs=parsed_inputs if parsed_inputs else None)
return str(result)
def create_flow_tools(flows: list[type] | None) -> list[BaseTool]:
"""Convert a list of Flow classes into BaseTool wrappers.
Args:
flows: Flow classes (not instances) to wrap as tools.
Returns:
A list of FlowTool instances ready for agent use.
"""
if not flows:
return []
tools: list[BaseTool] = []
for flow_cls in flows:
name = sanitize_tool_name(flow_cls.__name__)
docstring = (flow_cls.__doc__ or "").strip()
description = docstring if docstring else f"Run the {flow_cls.__name__} flow."
tools.append(
FlowTool(
name=name,
description=description,
flow_class=flow_cls,
)
)
return tools

View File

@@ -0,0 +1,185 @@
"""Tests for Flow-as-tool functionality."""
from __future__ import annotations
from unittest.mock import MagicMock
from crewai.flow.flow import Flow, start
from crewai.tools.flow_tool import FlowTool, create_flow_tools
# ---------------------------------------------------------------------------
# Test Flow classes
# ---------------------------------------------------------------------------
class SimpleFlow(Flow):
"""A simple flow that greets the user."""
@start()
def greet(self) -> str:
return "Hello from SimpleFlow!"
class MathFlow(Flow):
"""Performs basic math operations."""
@start()
def compute(self) -> str:
return "42"
class NoDocFlow(Flow):
@start()
def run_it(self) -> str:
return "no doc"
# ---------------------------------------------------------------------------
# FlowTool unit tests
# ---------------------------------------------------------------------------
class TestFlowTool:
def test_wrap_simple_flow(self) -> None:
tool = FlowTool(
name="simple_flow",
description="A simple flow that greets the user.",
flow_class=SimpleFlow,
)
assert tool.name == "simple_flow"
assert "greets the user" in tool.description
def test_run_invokes_kickoff(self) -> None:
mock_flow = MagicMock()
mock_flow.return_value = mock_flow # __init__ returns self
mock_flow.kickoff.return_value = "mocked result"
tool = FlowTool(
name="test_flow",
description="test",
flow_class=mock_flow,
)
result = tool._run(inputs="{}")
assert result == "mocked result"
mock_flow.kickoff.assert_called_once()
def test_run_with_json_inputs(self) -> None:
mock_flow = MagicMock()
mock_flow.return_value = mock_flow
mock_flow.kickoff.return_value = "result with inputs"
tool = FlowTool(
name="test_flow",
description="test",
flow_class=mock_flow,
)
result = tool._run(inputs='{"key": "value"}')
assert result == "result with inputs"
mock_flow.kickoff.assert_called_once_with(inputs={"key": "value"})
def test_run_with_invalid_json_defaults_to_empty(self) -> None:
mock_flow = MagicMock()
mock_flow.return_value = mock_flow
mock_flow.kickoff.return_value = "ok"
tool = FlowTool(
name="test_flow",
description="test",
flow_class=mock_flow,
)
result = tool._run(inputs="not valid json")
assert result == "ok"
mock_flow.kickoff.assert_called_once_with(inputs=None)
def test_run_returns_string(self) -> None:
mock_flow = MagicMock()
mock_flow.return_value = mock_flow
mock_flow.kickoff.return_value = 42
tool = FlowTool(
name="test_flow",
description="test",
flow_class=mock_flow,
)
result = tool._run()
assert result == "42"
assert isinstance(result, str)
# ---------------------------------------------------------------------------
# create_flow_tools tests
# ---------------------------------------------------------------------------
class TestCreateFlowTools:
def test_creates_tools_from_flow_classes(self) -> None:
tools = create_flow_tools([SimpleFlow, MathFlow])
assert len(tools) == 2
names = {t.name for t in tools}
assert "simple_flow" in names
assert "math_flow" in names
def test_description_from_docstring(self) -> None:
tools = create_flow_tools([SimpleFlow])
assert len(tools) == 1
assert "greets the user" in tools[0].description
def test_description_fallback_when_no_docstring(self) -> None:
tools = create_flow_tools([NoDocFlow])
assert len(tools) == 1
assert "NoDocFlow" in tools[0].description
def test_empty_list_returns_empty(self) -> None:
assert create_flow_tools([]) == []
def test_none_returns_empty(self) -> None:
assert create_flow_tools(None) == []
def test_tools_are_base_tool_instances(self) -> None:
from crewai.tools.base_tool import BaseTool
tools = create_flow_tools([SimpleFlow])
for tool in tools:
assert isinstance(tool, BaseTool)
# ---------------------------------------------------------------------------
# Agent integration tests
# ---------------------------------------------------------------------------
class TestAgentFlowIntegration:
def test_agent_with_flows_has_flow_tools(self) -> None:
from crewai.agent.core import Agent
agent = Agent(
role="Test Agent",
goal="Test flows",
backstory="I test things",
flows=[SimpleFlow, MathFlow],
)
tool_names = {t.name for t in (agent.tools or [])}
assert "simple_flow" in tool_names
assert "math_flow" in tool_names
def test_agent_without_flows_no_extra_tools(self) -> None:
from crewai.agent.core import Agent
agent = Agent(
role="Test Agent",
goal="Test",
backstory="I test things",
)
# Should not have any flow tools
flow_tool_names = {
t.name for t in (agent.tools or []) if isinstance(t, FlowTool)
}
assert len(flow_tool_names) == 0
def test_flow_tool_executes_real_flow(self) -> None:
"""Test that a FlowTool actually runs the Flow's kickoff."""
tools = create_flow_tools([SimpleFlow])
tool = tools[0]
result = tool.run(inputs="{}")
assert "Hello from SimpleFlow" in result

View File

@@ -1,411 +0,0 @@
"""Tests for LLM.call() tool loop and LLMResult.
All LLM calls are mocked — no real API traffic.
"""
from __future__ import annotations
import json
from types import SimpleNamespace
from typing import Any
from unittest.mock import MagicMock, patch
import pytest
from crewai.llm_result import (
LLMResult,
ToolCallRecord,
_lookup_pricing,
estimate_cost_usd,
)
def _make_litellm_llm(model: str = "gpt-4o") -> Any:
"""Create an LLM instance that uses the litellm fallback path."""
from crewai.llm import LLM
return LLM(model=model, is_litellm=True)
# ---------------------------------------------------------------------------
# Helpers
# ---------------------------------------------------------------------------
def _make_tool_call(name: str, arguments: dict, call_id: str = "call_1"):
"""Build a tool-call object using litellm's actual types."""
try:
from litellm.types.utils import (
ChatCompletionMessageToolCall,
Function,
)
return ChatCompletionMessageToolCall(
id=call_id,
function=Function(name=name, arguments=json.dumps(arguments)),
type="function",
)
except ImportError:
func = SimpleNamespace(name=name, arguments=json.dumps(arguments))
return SimpleNamespace(id=call_id, function=func, type="function")
def _make_model_response(content: str | None = None, tool_calls: list | None = None):
"""Build a minimal mock ModelResponse that passes isinstance checks.
We need it to be an instance of litellm's ModelResponse/ModelResponseBase
so the internal isinstance() checks work. We import those types when
litellm is available.
"""
try:
from litellm.types.utils import (
Choices,
Message,
ModelResponse,
Usage,
)
message = Message(content=content, tool_calls=tool_calls or None)
choice = Choices(message=message, finish_reason="stop", index=0)
resp = ModelResponse(
choices=[choice],
usage=Usage(
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
),
)
return resp
except ImportError:
# Fallback to SimpleNamespace if litellm not installed
message = SimpleNamespace(content=content, tool_calls=tool_calls or [])
choice = SimpleNamespace(message=message, finish_reason="stop")
usage = SimpleNamespace(
prompt_tokens=100,
completion_tokens=50,
total_tokens=150,
)
resp = SimpleNamespace(
choices=[choice],
model_extra={"usage": usage},
)
return resp
DUMMY_TOOL_SCHEMA = [
{
"type": "function",
"function": {
"name": "get_weather",
"description": "Get weather for a city",
"parameters": {
"type": "object",
"properties": {
"city": {"type": "string"},
},
"required": ["city"],
},
},
}
]
# ---------------------------------------------------------------------------
# Unit tests for LLMResult / ToolCallRecord
# ---------------------------------------------------------------------------
class TestLLMResultModels:
def test_tool_call_record_defaults(self):
r = ToolCallRecord(name="foo")
assert r.input == {}
assert r.output == ""
assert r.duration_ms == 0.0
assert r.is_error is False
def test_llm_result_defaults(self):
r = LLMResult()
assert r.text == ""
assert r.tool_calls == []
assert r.cost_usd == 0.0
assert r.iterations == 0
assert r.usage.total_tokens == 0
def test_llm_result_with_data(self):
r = LLMResult(
text="hello",
tool_calls=[ToolCallRecord(name="foo", input={"a": 1}, output="bar")],
iterations=2,
cost_usd=0.005,
)
assert r.text == "hello"
assert len(r.tool_calls) == 1
assert r.tool_calls[0].name == "foo"
# ---------------------------------------------------------------------------
# Cost estimation
# ---------------------------------------------------------------------------
class TestCostEstimation:
def test_known_model(self):
cost = estimate_cost_usd("gpt-4o", prompt_tokens=1_000_000, completion_tokens=0)
assert cost == pytest.approx(2.50)
def test_known_model_output(self):
cost = estimate_cost_usd("gpt-4o", prompt_tokens=0, completion_tokens=1_000_000)
assert cost == pytest.approx(10.00)
def test_unknown_model_returns_zero(self):
cost = estimate_cost_usd("some-random-model-xyz", 1000, 1000)
assert cost == 0.0
def test_provider_prefix_stripped(self):
cost = estimate_cost_usd("anthropic/claude-sonnet-4-6", 1_000_000, 0)
assert cost == pytest.approx(3.00)
def test_partial_match(self):
# "claude-sonnet-4-6-20250514" should match "claude-sonnet-4-6"
cost = estimate_cost_usd("claude-sonnet-4-6-20250514", 1_000_000, 0)
assert cost == pytest.approx(3.00)
def test_lookup_none(self):
assert _lookup_pricing("") is None
assert _lookup_pricing("nonexistent") is None
# ---------------------------------------------------------------------------
# LLM.call() backwards compatibility (no tools → returns str)
# ---------------------------------------------------------------------------
class TestCallBackwardsCompat:
"""LLM.call() without tools must return str exactly as before."""
@patch("crewai.llm.litellm")
def test_call_without_tools_returns_str(self, mock_litellm):
"""Plain call without tools should return a string."""
mock_litellm.completion.return_value = _make_model_response(content="Hello world")
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
llm = _make_litellm_llm()
result = llm.call("Say hello")
assert isinstance(result, str)
assert result == "Hello world"
# ---------------------------------------------------------------------------
# LLM.call() with tools → returns LLMResult
# ---------------------------------------------------------------------------
class TestCallWithToolLoop:
"""When tools + available_functions are passed, call() returns LLMResult."""
@patch("crewai.llm.litellm")
def test_single_tool_call_then_text(self, mock_litellm):
"""Model calls one tool, then responds with text."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
# First call: model wants to call get_weather
tool_call = _make_tool_call("get_weather", {"city": "SF"})
resp1 = _make_model_response(content=None, tool_calls=[tool_call])
# Second call: model responds with text
resp2 = _make_model_response(content="It's sunny in SF!")
mock_litellm.completion.side_effect = [resp1, resp2]
llm = _make_litellm_llm()
def get_weather(city: str) -> str:
return f"Sunny, 72°F in {city}"
result = llm.call(
messages="What's the weather in SF?",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": get_weather},
)
assert isinstance(result, LLMResult)
assert result.text == "It's sunny in SF!"
assert len(result.tool_calls) == 1
assert result.tool_calls[0].name == "get_weather"
assert result.tool_calls[0].input == {"city": "SF"}
assert "Sunny" in result.tool_calls[0].output
assert result.tool_calls[0].is_error is False
assert result.iterations == 2
@patch("crewai.llm.litellm")
def test_multiple_tool_calls_in_sequence(self, mock_litellm):
"""Model calls two tools across two iterations."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
tc1 = _make_tool_call("get_weather", {"city": "SF"}, "call_1")
resp1 = _make_model_response(content=None, tool_calls=[tc1])
tc2 = _make_tool_call("get_weather", {"city": "NYC"}, "call_2")
resp2 = _make_model_response(content=None, tool_calls=[tc2])
resp3 = _make_model_response(content="SF is sunny, NYC is rainy.")
mock_litellm.completion.side_effect = [resp1, resp2, resp3]
llm = _make_litellm_llm()
def get_weather(city: str) -> str:
return f"Weather for {city}: fine"
result = llm.call(
messages="Compare SF and NYC weather",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": get_weather},
)
assert isinstance(result, LLMResult)
assert len(result.tool_calls) == 2
assert result.tool_calls[0].input["city"] == "SF"
assert result.tool_calls[1].input["city"] == "NYC"
assert result.iterations == 3
@patch("crewai.llm.litellm")
def test_max_iterations_stops_loop(self, mock_litellm):
"""Loop stops when max_iterations is reached."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
# Model always wants to call a tool — never stops
def make_tool_resp():
tc = _make_tool_call("get_weather", {"city": "SF"})
return _make_model_response(content=None, tool_calls=[tc])
mock_litellm.completion.side_effect = [make_tool_resp() for _ in range(5)]
llm = _make_litellm_llm()
result = llm.call(
messages="Loop forever",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": lambda city: "sunny"},
max_iterations=3,
)
assert isinstance(result, LLMResult)
assert result.iterations == 3
assert len(result.tool_calls) == 3
# Should have a text noting max iterations
assert "Max iterations" in result.text
@patch("crewai.llm.litellm")
def test_tool_error_handling(self, mock_litellm):
"""Tool that raises an exception is captured in the record."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
tc = _make_tool_call("get_weather", {"city": "SF"})
resp1 = _make_model_response(content=None, tool_calls=[tc])
resp2 = _make_model_response(content="Sorry, couldn't get weather.")
mock_litellm.completion.side_effect = [resp1, resp2]
llm = _make_litellm_llm()
def broken_weather(city: str) -> str:
raise RuntimeError("API down")
result = llm.call(
messages="Weather?",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": broken_weather},
)
assert isinstance(result, LLMResult)
assert len(result.tool_calls) == 1
assert result.tool_calls[0].is_error is True
assert "API down" in result.tool_calls[0].output
assert result.text == "Sorry, couldn't get weather."
@patch("crewai.llm.litellm")
def test_unknown_function_error(self, mock_litellm):
"""Tool call for a function not in available_functions."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
tc = _make_tool_call("nonexistent_tool", {})
resp1 = _make_model_response(content=None, tool_calls=[tc])
resp2 = _make_model_response(content="I couldn't find that tool.")
mock_litellm.completion.side_effect = [resp1, resp2]
llm = _make_litellm_llm()
result = llm.call(
messages="Do something",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": lambda city: "sunny"},
)
assert isinstance(result, LLMResult)
assert result.tool_calls[0].is_error is True
assert "unknown function" in result.tool_calls[0].output
@patch("crewai.llm.litellm")
def test_cost_estimation_populated(self, mock_litellm):
"""cost_usd is populated from token usage and model pricing."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
resp = _make_model_response(content="Done!")
mock_litellm.completion.return_value = resp
llm = _make_litellm_llm()
result = llm.call(
messages="Hello",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": lambda city: "sunny"},
)
assert isinstance(result, LLMResult)
# cost_usd should be >= 0 (may be 0 if usage tracking didn't fire,
# but the field should exist and be a float)
assert isinstance(result.cost_usd, float)
@patch("crewai.llm.litellm")
def test_immediate_text_response_with_tools(self, mock_litellm):
"""Model responds with text on first call (no tool use)."""
mock_litellm.drop_params = True
mock_litellm.suppress_debug_info = True
mock_litellm.success_callback = []
mock_litellm._async_success_callback = []
mock_litellm.callbacks = []
resp = _make_model_response(content="I know the answer already.")
mock_litellm.completion.return_value = resp
llm = _make_litellm_llm()
result = llm.call(
messages="What's 2+2?",
tools=DUMMY_TOOL_SCHEMA,
available_functions={"get_weather": lambda city: "sunny"},
)
assert isinstance(result, LLMResult)
assert result.text == "I know the answer already."
assert len(result.tool_calls) == 0
assert result.iterations == 1

View File

@@ -164,7 +164,7 @@ info = "Commits must follow Conventional Commits 1.0.0."
[tool.uv]
# Pinned to include the security patch releases (authlib 1.6.11,
# langchain-text-splitters 1.1.2) uploaded on 2026-04-16.
exclude-newer = "2026-04-26"
exclude-newer = "2026-04-22"
# composio-core pins rich<14 but textual requires rich>=14.
# onnxruntime 1.24+ dropped Python 3.10 wheels; cap it so qdrant[fastembed] resolves on 3.10.