Merge branch 'main' into perf/reduce-framework-overhead

This commit is contained in:
Greyson LaLonde
2026-04-02 00:04:39 +08:00
committed by GitHub
71 changed files with 4418 additions and 1607 deletions

View File

@@ -152,4 +152,4 @@ __all__ = [
"wrap_file_source",
]
__version__ = "1.13.0rc1"
__version__ = "1.13.0a5"

View File

@@ -11,7 +11,7 @@ dependencies = [
"pytube~=15.0.0",
"requests~=2.32.5",
"docker~=7.1.0",
"crewai==1.13.0rc1",
"crewai==1.13.0a5",
"tiktoken~=0.8.0",
"beautifulsoup4~=4.13.4",
"python-docx~=1.2.0",

View File

@@ -309,4 +309,4 @@ __all__ = [
"ZapierActionTools",
]
__version__ = "1.13.0rc1"
__version__ = "1.13.0a5"

View File

@@ -14281,10 +14281,349 @@
],
"title": "EnvVar",
"type": "object"
},
"JsonResponseFormat": {
"description": "Response format requesting raw JSON output (e.g. ``{\"type\": \"json_object\"}``).",
"properties": {
"type": {
"const": "json_object",
"title": "Type",
"type": "string"
}
},
"required": [
"type"
],
"title": "JsonResponseFormat",
"type": "object"
},
"LLM": {
"properties": {
"additional_params": {
"additionalProperties": true,
"title": "Additional Params",
"type": "object"
},
"api_base": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Api Base"
},
"api_key": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Api Key"
},
"api_version": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Api Version"
},
"base_url": {
"anyOf": [
{
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Base Url"
},
"callbacks": {
"anyOf": [
{
"items": {},
"type": "array"
},
{
"type": "null"
}
],
"default": null,
"title": "Callbacks"
},
"completion_cost": {
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Completion Cost"
},
"context_window_size": {
"default": 0,
"title": "Context Window Size",
"type": "integer"
},
"frequency_penalty": {
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Frequency Penalty"
},
"interceptor": {
"default": null,
"title": "Interceptor"
},
"is_anthropic": {
"default": false,
"title": "Is Anthropic",
"type": "boolean"
},
"is_litellm": {
"default": false,
"title": "Is Litellm",
"type": "boolean"
},
"logit_bias": {
"anyOf": [
{
"additionalProperties": {
"type": "number"
},
"type": "object"
},
{
"type": "null"
}
],
"default": null,
"title": "Logit Bias"
},
"logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "Logprobs"
},
"max_completion_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "Max Completion Tokens"
},
"max_tokens": {
"anyOf": [
{
"type": "integer"
},
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Max Tokens"
},
"model": {
"title": "Model",
"type": "string"
},
"n": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "N"
},
"prefer_upload": {
"default": false,
"title": "Prefer Upload",
"type": "boolean"
},
"presence_penalty": {
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Presence Penalty"
},
"provider": {
"default": "openai",
"title": "Provider",
"type": "string"
},
"reasoning_effort": {
"anyOf": [
{
"enum": [
"none",
"low",
"medium",
"high"
],
"type": "string"
},
{
"type": "null"
}
],
"default": null,
"title": "Reasoning Effort"
},
"response_format": {
"anyOf": [
{
"$ref": "#/$defs/JsonResponseFormat"
},
{},
{
"type": "null"
}
],
"default": null,
"title": "Response Format"
},
"seed": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "Seed"
},
"stop": {
"items": {
"type": "string"
},
"title": "Stop",
"type": "array"
},
"stream": {
"default": false,
"title": "Stream",
"type": "boolean"
},
"temperature": {
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Temperature"
},
"thinking": {
"default": null,
"title": "Thinking"
},
"timeout": {
"anyOf": [
{
"type": "number"
},
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "Timeout"
},
"top_logprobs": {
"anyOf": [
{
"type": "integer"
},
{
"type": "null"
}
],
"default": null,
"title": "Top Logprobs"
},
"top_p": {
"anyOf": [
{
"type": "number"
},
{
"type": "null"
}
],
"default": null,
"title": "Top P"
}
},
"required": [
"model"
],
"title": "LLM",
"type": "object"
}
},
"description": "A tool for performing Optical Character Recognition on images.\n\nThis tool leverages LLMs to extract text from images. It can process\nboth local image files and images available via URLs.\n\nAttributes:\n name (str): Name of the tool.\n description (str): Description of the tool's functionality.\n args_schema (Type[BaseModel]): Pydantic schema for input validation.\n\nPrivate Attributes:\n _llm (Optional[LLM]): Language model instance for making API calls.",
"properties": {},
"properties": {
"llm": {
"$ref": "#/$defs/LLM"
}
},
"title": "OCRTool",
"type": "object"
},

View File

@@ -43,7 +43,7 @@ dependencies = [
"uv~=0.9.13",
"aiosqlite~=0.21.0",
"pyyaml~=6.0",
"lancedb>=0.29.2",
"lancedb>=0.29.2,<0.30.1",
]
[project.urls]
@@ -54,7 +54,7 @@ Repository = "https://github.com/crewAIInc/crewAI"
[project.optional-dependencies]
tools = [
"crewai-tools==1.13.0rc1",
"crewai-tools==1.13.0a5",
]
embeddings = [
"tiktoken~=0.8.0"

View File

@@ -4,6 +4,8 @@ from typing import Any
import urllib.request
import warnings
from pydantic import PydanticUserError
from crewai.agent.core import Agent
from crewai.agent.planning_config import PlanningConfig
from crewai.crew import Crew
@@ -42,7 +44,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.13.0rc1"
__version__ = "1.13.0a5"
_telemetry_submitted = False
@@ -93,6 +95,38 @@ def __getattr__(name: str) -> Any:
raise AttributeError(f"module 'crewai' has no attribute {name!r}")
try:
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.experimental.agent_executor import AgentExecutor as _AgentExecutor
from crewai.hooks.llm_hooks import LLMCallHookContext as _LLMCallHookContext
from crewai.tools.tool_types import ToolResult as _ToolResult
from crewai.utilities.prompts import (
StandardPromptResult as _StandardPromptResult,
SystemPromptResult as _SystemPromptResult,
)
_AgentExecutor.model_rebuild(
force=True,
_types_namespace={
"Agent": Agent,
"ToolsHandler": _ToolsHandler,
"Crew": Crew,
"BaseLLM": BaseLLM,
"Task": Task,
"StandardPromptResult": _StandardPromptResult,
"SystemPromptResult": _SystemPromptResult,
"LLMCallHookContext": _LLMCallHookContext,
"ToolResult": _ToolResult,
},
)
except (ImportError, PydanticUserError):
import logging as _logging
_logging.getLogger(__name__).warning(
"AgentExecutor.model_rebuild() failed; forward refs may be unresolved.",
exc_info=True,
)
__all__ = [
"LLM",
"Agent",

View File

@@ -25,7 +25,6 @@ from pydantic import (
BaseModel,
ConfigDict,
Field,
InstanceOf,
PrivateAttr,
model_validator,
)
@@ -167,10 +166,10 @@ class Agent(BaseAgent):
default=True,
description="Use system prompt for the agent.",
)
llm: str | InstanceOf[BaseLLM] | None = Field(
llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
function_calling_llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
system_template: str | None = Field(
@@ -1012,7 +1011,7 @@ class Agent(BaseAgent):
self.agent_executor.tools = tools
self.agent_executor.original_tools = raw_tools
self.agent_executor.prompt = prompt
self.agent_executor.stop = stop_words
self.agent_executor.stop_words = stop_words
self.agent_executor.tools_names = get_tool_names(tools)
self.agent_executor.tools_description = render_text_description_and_args(tools)
self.agent_executor.response_model = (

View File

@@ -12,7 +12,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -185,7 +184,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
default=None,
description="Knowledge sources for the agent.",
)
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
knowledge_storage: BaseKnowledgeStorage | None = Field(
default=None,
description="Custom knowledge storage for the agent.",
)

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.13.0rc1"
"crewai[tools]==1.13.0a5"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.13.0rc1"
"crewai[tools]==1.13.0a5"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "Power up your crews with {{folder_name}}"
readme = "README.md"
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.13.0rc1"
"crewai[tools]==1.13.0a5"
]
[tool.crewai]

View File

@@ -22,7 +22,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
Json,
PrivateAttr,
field_validator,
@@ -176,7 +175,7 @@ class Crew(FlowTrackable, BaseModel):
_rpm_controller: RPMController = PrivateAttr()
_logger: Logger = PrivateAttr()
_file_handler: FileHandler = PrivateAttr()
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
_train: bool | None = PrivateAttr(default=False)
_train_iteration: int | None = PrivateAttr()
@@ -210,13 +209,13 @@ class Crew(FlowTrackable, BaseModel):
default=None,
description="Metrics for the LLM usage during all tasks execution.",
)
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
manager_llm: str | BaseLLM | None = Field(
description="Language model that will run the agent.", default=None
)
manager_agent: BaseAgent | None = Field(
description="Custom agent that will be used as manager.", default=None
)
function_calling_llm: str | InstanceOf[LLM] | None = Field(
function_calling_llm: str | LLM | None = Field(
description="Language model that will run the agent.", default=None
)
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
@@ -267,7 +266,7 @@ class Crew(FlowTrackable, BaseModel):
default=False,
description="Plan the crew execution and add the plan to the crew.",
)
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
planning_llm: str | BaseLLM | Any | None = Field(
default=None,
description=(
"Language model that will run the AgentPlanner if planning is True."
@@ -288,7 +287,7 @@ class Crew(FlowTrackable, BaseModel):
"knowledge object."
),
)
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
chat_llm: str | BaseLLM | Any | None = Field(
default=None,
description="LLM used to handle chatting with the crew.",
)
@@ -1800,7 +1799,7 @@ class Crew(FlowTrackable, BaseModel):
def test(
self,
n_iterations: int,
eval_llm: str | InstanceOf[BaseLLM],
eval_llm: str | BaseLLM,
inputs: dict[str, Any] | None = None,
) -> None:
"""Test and evaluate the Crew with the given inputs for n iterations.

View File

@@ -57,6 +57,7 @@ class LLMCallCompletedEvent(LLMEventBase):
messages: str | list[dict[str, Any]] | None = None
response: Any
call_type: LLMCallType
usage: dict[str, Any] | None = None
class LLMCallFailedEvent(LLMEventBase):

View File

@@ -11,10 +11,15 @@ import threading
from typing import TYPE_CHECKING, Any, Literal, TypeVar, cast
from uuid import uuid4
from pydantic import BaseModel, Field, GetCoreSchemaHandler
from pydantic_core import CoreSchema, core_schema
from pydantic import (
BaseModel,
Field,
PrivateAttr,
model_validator,
)
from rich.console import Console
from rich.text import Text
from typing_extensions import Self
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
from crewai.agents.parser import (
@@ -119,6 +124,7 @@ class AgentExecutorState(BaseModel):
(todos, observations, replan tracking) in a single validated model.
"""
id: str = Field(default_factory=lambda: str(uuid4()))
messages: list[LLMMessage] = Field(default_factory=list)
iterations: int = Field(default=0)
current_answer: AgentAction | AgentFinish | None = Field(default=None)
@@ -152,6 +158,9 @@ class AgentExecutorState(BaseModel):
class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""Agent Executor for both standalone agents and crew-bound agents.
_skip_auto_memory prevents Flow from eagerly allocating a Memory
instance — the executor uses agent/crew memory, not its own.
Inherits from:
- Flow[AgentExecutorState]: Provides flow orchestration capabilities
- CrewAgentExecutorMixin: Provides memory methods (short/long/external term)
@@ -159,136 +168,74 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
This executor can operate in two modes:
- Standalone mode: When crew and task are None (used by Agent.kickoff())
- Crew mode: When crew and task are provided (used by Agent.execute_task())
Note: Multiple instances may be created during agent initialization
(cache setup, RPM controller setup, etc.) but only the final instance
should execute tasks via invoke().
"""
def __init__(
self,
llm: BaseLLM,
agent: Agent,
prompt: SystemPromptResult | StandardPromptResult,
max_iter: int,
tools: list[CrewStructuredTool],
tools_names: str,
stop_words: list[str],
tools_description: str,
tools_handler: ToolsHandler,
task: Task | None = None,
crew: Crew | None = None,
step_callback: Any = None,
original_tools: list[BaseTool] | None = None,
function_calling_llm: BaseLLM | Any | None = None,
respect_context_window: bool = False,
request_within_rpm_limit: Callable[[], bool] | None = None,
callbacks: list[Any] | None = None,
response_model: type[BaseModel] | None = None,
i18n: I18N | None = None,
) -> None:
"""Initialize the flow-based agent executor.
_skip_auto_memory: bool = True
Args:
llm: Language model instance.
agent: Agent to execute.
prompt: Prompt templates.
max_iter: Maximum iterations.
tools: Available tools.
tools_names: Tool names string.
stop_words: Stop word list.
tools_description: Tool descriptions.
tools_handler: Tool handler instance.
task: Optional task to execute (None for standalone agent execution).
crew: Optional crew instance (None for standalone agent execution).
step_callback: Optional step callback.
original_tools: Original tool list.
function_calling_llm: Optional function calling LLM.
respect_context_window: Respect context limits.
request_within_rpm_limit: RPM limit check function.
callbacks: Optional callbacks list.
response_model: Optional Pydantic model for structured outputs.
"""
self._i18n: I18N = i18n or get_i18n()
self.llm = llm
self.task: Task | None = task
self.agent = agent
self.crew: Crew | None = crew
self.prompt = prompt
self.tools = tools
self.tools_names = tools_names
self.stop = stop_words
self.max_iter = max_iter
self.callbacks = callbacks or []
self._printer: Printer = Printer()
self.tools_handler = tools_handler
self.original_tools = original_tools or []
self.step_callback = step_callback
self.tools_description = tools_description
self.function_calling_llm = function_calling_llm
self.respect_context_window = respect_context_window
self.request_within_rpm_limit = request_within_rpm_limit
self.response_model = response_model
self.log_error_after = 3
self._console: Console = Console()
suppress_flow_events: bool = True # always suppress for executor
llm: BaseLLM = Field(exclude=True)
agent: Agent = Field(exclude=True)
prompt: SystemPromptResult | StandardPromptResult = Field(exclude=True)
max_iter: int = Field(default=25, exclude=True)
tools: list[CrewStructuredTool] = Field(default_factory=list, exclude=True)
tools_names: str = Field(default="", exclude=True)
stop_words: list[str] = Field(default_factory=list, exclude=True)
tools_description: str = Field(default="", exclude=True)
tools_handler: ToolsHandler | None = Field(default=None, exclude=True)
task: Task | None = Field(default=None, exclude=True)
crew: Crew | None = Field(default=None, exclude=True)
step_callback: Any = Field(default=None, exclude=True)
original_tools: list[BaseTool] = Field(default_factory=list, exclude=True)
function_calling_llm: BaseLLM | None = Field(default=None, exclude=True)
respect_context_window: bool = Field(default=False, exclude=True)
request_within_rpm_limit: Callable[[], bool] | None = Field(
default=None, exclude=True
)
callbacks: list[Any] = Field(default_factory=list, exclude=True)
response_model: type[BaseModel] | None = Field(default=None, exclude=True)
i18n: I18N | None = Field(default=None, exclude=True)
log_error_after: int = Field(default=3, exclude=True)
before_llm_call_hooks: list[BeforeLLMCallHookType | BeforeLLMCallHookCallable] = (
Field(default_factory=list, exclude=True)
)
after_llm_call_hooks: list[AfterLLMCallHookType | AfterLLMCallHookCallable] = Field(
default_factory=list, exclude=True
)
# Error context storage for recovery
self._last_parser_error: OutputParserError | None = None
self._last_context_error: Exception | None = None
_i18n: I18N = PrivateAttr(default_factory=get_i18n)
_printer: Printer = PrivateAttr(default_factory=Printer)
_console: Console = PrivateAttr(default_factory=Console)
_last_parser_error: OutputParserError | None = PrivateAttr(default=None)
_last_context_error: Exception | None = PrivateAttr(default=None)
_execution_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_finalize_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_finalize_called: bool = PrivateAttr(default=False)
_is_executing: bool = PrivateAttr(default=False)
_has_been_invoked: bool = PrivateAttr(default=False)
_instance_id: str = PrivateAttr(default_factory=lambda: str(uuid4())[:8])
_step_executor: Any = PrivateAttr(default=None)
_planner_observer: Any = PrivateAttr(default=None)
# Execution guard to prevent concurrent/duplicate executions
self._execution_lock = threading.Lock()
self._finalize_lock = threading.Lock()
self._finalize_called: bool = False
self._is_executing: bool = False
self._has_been_invoked: bool = False
self._flow_initialized: bool = False
self._instance_id = str(uuid4())[:8]
self.before_llm_call_hooks: list[
BeforeLLMCallHookType | BeforeLLMCallHookCallable
] = []
self.after_llm_call_hooks: list[
AfterLLMCallHookType | AfterLLMCallHookCallable
] = []
@model_validator(mode="after")
def _setup_executor(self) -> Self:
"""Configure executor after Pydantic field initialization."""
self._i18n = self.i18n or get_i18n()
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
existing_stop = getattr(self.llm, "stop", [])
self.llm.stop = list(
set(
existing_stop + self.stop
if isinstance(existing_stop, list)
else self.stop
)
)
if not isinstance(existing_stop, list):
existing_stop = []
self.llm.stop = list(set(existing_stop + self.stop_words))
self._state = AgentExecutorState()
self.max_method_calls = self.max_iter * 10
# Plan-and-Execute components (Phase 2)
# Lazy-imported to avoid circular imports during module load
self._step_executor: Any = None
self._planner_observer: Any = None
def _ensure_flow_initialized(self) -> None:
"""Ensure Flow.__init__() has been called.
This is deferred from __init__ to prevent FlowCreatedEvent emission
during agent setup when multiple executor instances are created.
Only the instance that actually executes via invoke() will emit events.
"""
if not self._flow_initialized:
current_tracing = is_tracing_enabled_in_context()
# Now call Flow's __init__ which will replace self._state
# with Flow's managed state. Suppress flow events since this is
# an agent executor, not a user-facing flow.
super().__init__(
suppress_flow_events=True,
tracing=current_tracing if current_tracing else None,
max_method_calls=self.max_iter * 10,
)
self._flow_initialized = True
current_tracing = is_tracing_enabled_in_context()
self.tracing = current_tracing if current_tracing else None
self._flow_post_init()
return self
def _check_native_tool_support(self) -> bool:
"""Check if LLM supports native function calling."""
@@ -318,19 +265,13 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
@property
def state(self) -> AgentExecutorState:
"""Get state - returns temporary state if Flow not yet initialized.
Flow initialization is deferred to prevent event emission during agent setup.
Returns the temporary state until invoke() is called.
"""
if self._flow_initialized and hasattr(self, "_state_lock"):
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
return self._state
"""Get thread-safe state proxy."""
return StateProxy(self._state, self._state_lock) # type: ignore[return-value]
@property
def iterations(self) -> int:
"""Compatibility property for mixin - returns state iterations."""
return self._state.iterations
return self._state.iterations # type: ignore[no-any-return]
@iterations.setter
def iterations(self, value: int) -> None:
@@ -340,7 +281,7 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
@property
def messages(self) -> list[LLMMessage]:
"""Compatibility property - returns state messages."""
return self._state.messages
return self._state.messages # type: ignore[no-any-return]
@messages.setter
def messages(self, value: list[LLMMessage]) -> None:
@@ -1966,42 +1907,10 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"original_tool": original_tool,
}
def _extract_tool_name(self, tool_call: Any) -> str:
"""Extract tool name from various tool call formats."""
if hasattr(tool_call, "function"):
return sanitize_tool_name(tool_call.function.name)
if hasattr(tool_call, "function_call") and tool_call.function_call:
return sanitize_tool_name(tool_call.function_call.name)
if hasattr(tool_call, "name"):
return sanitize_tool_name(tool_call.name)
if isinstance(tool_call, dict):
func_info = tool_call.get("function", {})
return sanitize_tool_name(
func_info.get("name", "") or tool_call.get("name", "unknown")
)
return "unknown"
@router(execute_native_tool)
def check_native_todo_completion(
self,
) -> Literal["todo_satisfied", "todo_not_satisfied"]:
"""Check if the native tool execution satisfied the active todo.
Similar to check_todo_completion but for native tool execution path.
"""
current_todo = self.state.todos.current_todo
if not current_todo:
return "todo_not_satisfied"
# For native tools, any tool execution satisfies the todo
return "todo_satisfied"
@listen("initialized")
def continue_iteration(self) -> Literal["check_iteration"]:
"""Bridge listener that connects iteration loop back to iteration check."""
if self._flow_initialized:
self._discard_or_listener(FlowMethodName("continue_iteration"))
self._discard_or_listener(FlowMethodName("continue_iteration"))
return "check_iteration"
@router(or_(initialize_reasoning, continue_iteration))
@@ -2629,8 +2538,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
if is_inside_event_loop():
return self.invoke_async(inputs)
self._ensure_flow_initialized()
with self._execution_lock:
if self._is_executing:
raise RuntimeError(
@@ -2721,8 +2628,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
Returns:
Dictionary with agent output.
"""
self._ensure_flow_initialized()
with self._execution_lock:
if self._is_executing:
raise RuntimeError(
@@ -3038,17 +2943,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
"""
return bool(self.crew and self.crew._train)
@classmethod
def __get_pydantic_core_schema__(
cls, _source_type: Any, _handler: GetCoreSchemaHandler
) -> CoreSchema:
"""Generate Pydantic core schema for Protocol compatibility.
Allows the executor to be used in Pydantic models without
requiring arbitrary_types_allowed=True.
"""
return core_schema.any_schema()
# Backward compatibility alias (deprecated)
CrewAgentExecutorFlow = AgentExecutor

View File

@@ -39,7 +39,14 @@ from uuid import uuid4
from opentelemetry import baggage
from opentelemetry.context import attach, detach
from pydantic import BaseModel, Field, ValidationError
from pydantic import (
BaseModel,
ConfigDict,
Field,
PrivateAttr,
ValidationError,
)
from pydantic._internal._model_construction import ModelMetaclass
from rich.console import Console
from rich.panel import Panel
@@ -81,6 +88,7 @@ from crewai.flow.flow_wrappers import (
SimpleFlowCondition,
StartMethod,
)
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.flow.input_provider import InputProvider
from crewai.flow.persistence.base import FlowPersistence
from crewai.flow.types import (
@@ -108,7 +116,6 @@ if TYPE_CHECKING:
from crewai_files import FileInput
from crewai.flow.async_feedback.types import PendingFeedbackContext
from crewai.flow.human_feedback import HumanFeedbackResult
from crewai.llms.base_llm import BaseLLM
from crewai.flow.visualization import build_flow_structure, render_interactive
@@ -728,7 +735,7 @@ class StateProxy(Generic[T]):
return result
class FlowMeta(type):
class FlowMeta(ModelMetaclass):
def __new__(
mcs,
name: str,
@@ -736,6 +743,45 @@ class FlowMeta(type):
namespace: dict[str, Any],
**kwargs: Any,
) -> type:
parent_fields: set[str] = set()
for base in bases:
if hasattr(base, "model_fields"):
parent_fields.update(base.model_fields)
annotations = namespace.get("__annotations__", {})
_skip_types = (classmethod, staticmethod, property)
for base in bases:
if isinstance(base, ModelMetaclass):
continue
for attr_name in getattr(base, "__annotations__", {}):
if attr_name not in annotations and attr_name not in namespace:
annotations[attr_name] = ClassVar
for attr_name, attr_value in namespace.items():
if isinstance(attr_value, property) and attr_name not in annotations:
for base in bases:
base_ann = getattr(base, "__annotations__", {})
if attr_name in base_ann:
annotations[attr_name] = ClassVar
for attr_name, attr_value in list(namespace.items()):
if attr_name in annotations or attr_name.startswith("_"):
continue
if attr_name in parent_fields:
annotations[attr_name] = Any
if isinstance(attr_value, BaseModel):
namespace[attr_name] = Field(
default_factory=lambda v=attr_value: v, exclude=True
)
continue
if callable(attr_value) or isinstance(
attr_value, (*_skip_types, FlowMethod)
):
continue
annotations[attr_name] = ClassVar[type(attr_value)]
namespace["__annotations__"] = annotations
cls = super().__new__(mcs, name, bases, namespace)
start_methods = []
@@ -820,88 +866,90 @@ class FlowMeta(type):
return cls
class Flow(Generic[T], metaclass=FlowMeta):
class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
"""Base class for all flows.
type parameter T must be either dict[str, Any] or a subclass of BaseModel."""
model_config = ConfigDict(
arbitrary_types_allowed=True,
ignored_types=(StartMethod, ListenMethod, RouterMethod),
revalidate_instances="never",
)
__hash__ = object.__hash__
_start_methods: ClassVar[list[FlowMethodName]] = []
_listeners: ClassVar[dict[FlowMethodName, SimpleFlowCondition | FlowCondition]] = {}
_routers: ClassVar[set[FlowMethodName]] = set()
_router_paths: ClassVar[dict[FlowMethodName, list[FlowMethodName]]] = {}
initial_state: type[T] | T | None = None
name: str | None = None
tracing: bool | None = None
stream: bool = False
memory: Memory | MemoryScope | MemorySlice | None = None
input_provider: InputProvider | None = None
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]:
class _FlowGeneric(cls): # type: ignore
_initial_state_t = item
initial_state: Any = Field(default=None)
name: str | None = Field(default=None)
tracing: bool | None = Field(default=None)
stream: bool = Field(default=False)
memory: Memory | MemoryScope | MemorySlice | None = Field(default=None)
input_provider: InputProvider | None = Field(default=None)
suppress_flow_events: bool = Field(default=False)
human_feedback_history: list[HumanFeedbackResult] = Field(default_factory=list)
last_human_feedback: HumanFeedbackResult | None = Field(default=None)
persistence: Any = Field(default=None, exclude=True)
max_method_calls: int = Field(default=100, exclude=True)
_methods: dict[FlowMethodName, FlowMethod[Any, Any]] = PrivateAttr(
default_factory=dict
)
_method_execution_counts: dict[FlowMethodName, int] = PrivateAttr(
default_factory=dict
)
_pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = PrivateAttr(
default_factory=dict
)
_fired_or_listeners: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_outputs: list[Any] = PrivateAttr(default_factory=list)
_state_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_or_listeners_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
_completed_methods: set[FlowMethodName] = PrivateAttr(default_factory=set)
_method_call_counts: dict[FlowMethodName, int] = PrivateAttr(default_factory=dict)
_is_execution_resuming: bool = PrivateAttr(default=False)
_event_futures: list[Future[None]] = PrivateAttr(default_factory=list)
_pending_feedback_context: PendingFeedbackContext | None = PrivateAttr(default=None)
_human_feedback_method_outputs: dict[str, Any] = PrivateAttr(default_factory=dict)
_input_history: list[InputHistoryEntry] = PrivateAttr(default_factory=list)
_state: Any = PrivateAttr(default=None)
def __class_getitem__(cls: type[Flow[T]], item: type[T]) -> type[Flow[T]]: # type: ignore[override]
class _FlowGeneric(cls): # type: ignore[valid-type,misc]
pass
_FlowGeneric.__name__ = f"{cls.__name__}[{item.__name__}]"
_FlowGeneric._initial_state_t = item
return _FlowGeneric
def __init__(
self,
persistence: FlowPersistence | None = None,
tracing: bool | None = None,
suppress_flow_events: bool = False,
max_method_calls: int = 100,
**kwargs: Any,
) -> None:
"""Initialize a new Flow instance.
def __setattr__(self, name: str, value: Any) -> None:
"""Allow arbitrary attribute assignment for backward compat with plain class."""
if name in self.model_fields or name in self.__private_attributes__:
super().__setattr__(name, value)
else:
object.__setattr__(self, name, value)
Args:
persistence: Optional persistence backend for storing flow states
tracing: Whether to enable tracing. True=always enable, False=always disable, None=check environment/user settings
suppress_flow_events: Whether to suppress flow event emissions (internal use)
max_method_calls: Maximum times a single method can be called per execution before raising RecursionError
**kwargs: Additional state values to initialize or override
"""
# Initialize basic instance attributes
self._methods: dict[FlowMethodName, FlowMethod[Any, Any]] = {}
self._method_execution_counts: dict[FlowMethodName, int] = {}
self._pending_and_listeners: dict[PendingListenerKey, set[FlowMethodName]] = {}
self._fired_or_listeners: set[FlowMethodName] = (
set()
) # Track OR listeners that already fired
self._method_outputs: list[Any] = [] # list to store all method outputs
self._state_lock = threading.Lock()
self._or_listeners_lock = threading.Lock()
self._completed_methods: set[FlowMethodName] = (
set()
) # Track completed methods for reload
self._method_call_counts: dict[FlowMethodName, int] = {}
self._max_method_calls = max_method_calls
self._persistence: FlowPersistence | None = persistence
self._is_execution_resuming: bool = False
self._event_futures: list[Future[None]] = []
def model_post_init(self, __context: Any) -> None:
self._flow_post_init()
# Human feedback storage
self.human_feedback_history: list[HumanFeedbackResult] = []
self.last_human_feedback: HumanFeedbackResult | None = None
self._pending_feedback_context: PendingFeedbackContext | None = None
# Per-method stash for real @human_feedback output (keyed by method name)
# Used to decouple routing outcome from method return value when emit is set
self._human_feedback_method_outputs: dict[str, Any] = {}
self.suppress_flow_events: bool = suppress_flow_events
def _flow_post_init(self) -> None:
"""Heavy initialization: state creation, events, memory, method registration."""
if getattr(self, "_flow_post_init_done", False):
return
object.__setattr__(self, "_flow_post_init_done", True)
# User input history (for self.ask())
self._input_history: list[InputHistoryEntry] = []
if self._state is None:
self._state = self._create_initial_state()
# Initialize state with initial values
self._state = self._create_initial_state()
self.tracing = tracing
tracing_enabled = should_enable_tracing(override=self.tracing)
set_tracing_enabled(tracing_enabled)
trace_listener = TraceCollectionListener()
trace_listener.setup_listeners(crewai_event_bus)
# Apply any additional kwargs
if kwargs:
self._initialize_state(kwargs)
if not self.suppress_flow_events:
crewai_event_bus.emit(
@@ -1385,8 +1433,8 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._pending_feedback_context = None
# Clear pending feedback from persistence
if self._persistence:
self._persistence.clear_pending_feedback(context.flow_id)
if self.persistence:
self.persistence.clear_pending_feedback(context.flow_id)
# Emit feedback received event
crewai_event_bus.emit(
@@ -1427,17 +1475,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
self._pending_feedback_context = e.context
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_pending_feedback(
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
@@ -1487,39 +1535,33 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
init_state = self.initial_state
# Handle case where initial_state is None but we have a type parameter
if init_state is None and hasattr(self, "_initial_state_t"):
state_type = self._initial_state_t
if isinstance(state_type, type):
if issubclass(state_type, FlowState):
# Create instance - FlowState auto-generates id via default_factory
instance = state_type()
# Ensure id is set - generate UUID if empty
if not getattr(instance, "id", None):
object.__setattr__(instance, "id", str(uuid4()))
return cast(T, instance)
if issubclass(state_type, BaseModel):
# Create a new type with FlowState first for proper id default
class StateWithId(FlowState, state_type): # type: ignore
pass
instance = StateWithId()
# Ensure id is set - generate UUID if empty
if not getattr(instance, "id", None):
object.__setattr__(instance, "id", str(uuid4()))
return cast(T, instance)
if state_type is dict:
return cast(T, {"id": str(uuid4())})
# Handle case where no initial state is provided
if init_state is None:
return cast(T, {"id": str(uuid4())})
# Handle case where initial_state is a type (class)
if isinstance(init_state, type):
state_class = init_state
if issubclass(state_class, FlowState):
return state_class()
return cast(T, state_class())
if issubclass(state_class, BaseModel):
model_fields = getattr(state_class, "model_fields", None)
if not model_fields or "id" not in model_fields:
@@ -1527,7 +1569,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
model_instance = state_class()
if not getattr(model_instance, "id", None):
object.__setattr__(model_instance, "id", str(uuid4()))
return model_instance
return cast(T, model_instance)
if init_state is dict:
return cast(T, {"id": str(uuid4())})
@@ -1538,32 +1580,21 @@ class Flow(Generic[T], metaclass=FlowMeta):
new_state["id"] = str(uuid4())
return cast(T, new_state)
# Handle BaseModel instance case
if isinstance(init_state, BaseModel):
model = cast(BaseModel, init_state)
if not hasattr(model, "id"):
raise ValueError("Flow state model must have an 'id' field")
# Create new instance with same values to avoid mutations
if hasattr(model, "model_dump"):
# Pydantic v2
model = init_state
if hasattr(model, "id"):
state_dict = model.model_dump()
elif hasattr(model, "dict"):
# Pydantic v1
state_dict = model.dict()
else:
# Fallback for other BaseModel implementations
state_dict = {
k: v for k, v in model.__dict__.items() if not k.startswith("_")
}
if not state_dict.get("id"):
state_dict["id"] = str(uuid4())
model_class = type(model)
return cast(T, model_class(**state_dict))
# Ensure id is set - generate UUID if empty
if not state_dict.get("id"):
state_dict["id"] = str(uuid4())
class StateWithId(FlowState, type(model)): # type: ignore
pass
# Create new instance of the same class
model_class = type(model)
return cast(T, model_class(**state_dict))
state_dict = model.model_dump()
state_dict["id"] = str(uuid4())
return cast(T, StateWithId(**state_dict))
raise TypeError(
f"Initial state must be dict or BaseModel, got {type(self.initial_state)}"
)
@@ -1576,17 +1607,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
"""
if isinstance(self._state, BaseModel):
try:
return self._state.model_copy(deep=True)
return cast(T, self._state.model_copy(deep=True))
except (TypeError, AttributeError):
try:
state_dict = self._state.model_dump()
model_class = type(self._state)
return model_class(**state_dict)
return cast(T, model_class(**state_dict))
except Exception:
return self._state.model_copy(deep=False)
return cast(T, self._state.model_copy(deep=False))
else:
try:
return copy.deepcopy(self._state)
return cast(T, copy.deepcopy(self._state))
except (TypeError, AttributeError):
return cast(T, self._state.copy())
@@ -1662,7 +1693,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
elif isinstance(self._state, BaseModel):
# For BaseModel states, preserve existing fields unless overridden
try:
model = cast(BaseModel, self._state)
model = self._state
# Get current state as dict
if hasattr(model, "model_dump"):
current_state = model.model_dump()
@@ -1713,7 +1744,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
self._state.update(stored_state)
elif isinstance(self._state, BaseModel):
# For BaseModel states, create new instance with stored values
model = cast(BaseModel, self._state)
model = self._state
if hasattr(model, "model_validate"):
# Pydantic v2
self._state = cast(T, type(model).model_validate(stored_state))
@@ -1938,7 +1969,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
try:
# Reset flow state for fresh execution unless restoring from persistence
is_restoring = inputs and "id" in inputs and self._persistence is not None
is_restoring = inputs and "id" in inputs and self.persistence is not None
if not is_restoring:
# Clear completed methods and outputs for a fresh start
self._completed_methods.clear()
@@ -1964,9 +1995,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
setattr(self._state, "id", inputs["id"]) # noqa: B010
# If persistence is enabled, attempt to restore the stored state using the provided id.
if "id" in inputs and self._persistence is not None:
if "id" in inputs and self.persistence is not None:
restore_uuid = inputs["id"]
stored_state = self._persistence.load_state(restore_uuid)
stored_state = self.persistence.load_state(restore_uuid)
if stored_state:
self._log_flow_event(
f"Loading flow state from memory for UUID: {restore_uuid}"
@@ -2036,17 +2067,17 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
# Auto-save pending feedback (create default persistence if needed)
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
state_data = (
self._state
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_pending_feedback(
self.persistence.save_pending_feedback(
flow_uuid=e.context.flow_id,
context=e.context,
state_data=state_data,
@@ -2332,10 +2363,10 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(e, HumanFeedbackPending):
e.context.method_name = method_name
if self._persistence is None:
if self.persistence is None:
from crewai.flow.persistence import SQLiteFlowPersistence
self._persistence = SQLiteFlowPersistence()
self.persistence = SQLiteFlowPersistence()
# Emit paused event (not failed)
if not self.suppress_flow_events:
@@ -2696,9 +2727,9 @@ class Flow(Generic[T], metaclass=FlowMeta):
- Catches and logs any exceptions during execution, preventing individual listener failures from breaking the entire flow
"""
count = self._method_call_counts.get(listener_name, 0) + 1
if count > self._max_method_calls:
if count > self.max_method_calls:
raise RecursionError(
f"Method '{listener_name}' has been called {self._max_method_calls} times in "
f"Method '{listener_name}' has been called {self.max_method_calls} times in "
f"this flow execution, which indicates an infinite loop. "
f"This commonly happens when a @listen label matches the "
f"method's own name."
@@ -2805,7 +2836,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
This is best-effort: if persistence is not configured, this is a no-op.
"""
if self._persistence is None:
if self.persistence is None:
return
try:
state_data = (
@@ -2813,7 +2844,7 @@ class Flow(Generic[T], metaclass=FlowMeta):
if isinstance(self._state, dict)
else self._state.model_dump()
)
self._persistence.save_state(
self.persistence.save_state(
flow_uuid=self.flow_id,
method_name="_ask_checkpoint",
state_data=state_data,

View File

@@ -3,12 +3,15 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, ConfigDict
if TYPE_CHECKING:
from crewai.rag.types import SearchResult
class BaseKnowledgeStorage(ABC):
class BaseKnowledgeStorage(BaseModel, ABC):
model_config = ConfigDict(arbitrary_types_allowed=True)
"""Abstract base class for knowledge storage implementations."""
@abstractmethod

View File

@@ -3,6 +3,9 @@ import traceback
from typing import Any, cast
import warnings
from pydantic import Field, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
from crewai.rag.chromadb.config import ChromaDBConfig
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
search efficiency.
"""
def __init__(
self,
embedder: ProviderSpec
collection_name: str | None = None
embedder: (
ProviderSpec
| BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]]
| None = None,
collection_name: str | None = None,
) -> None:
self.collection_name = collection_name
self._client: BaseClient | None = None
| None
) = Field(default=None, exclude=True)
_client: BaseClient | None = PrivateAttr(default=None)
@model_validator(mode="after")
def _init_client(self) -> Self:
warnings.filterwarnings(
"ignore",
message=r".*'model_fields'.*is deprecated.*",
module=r"^chromadb(\.|$)",
)
if embedder:
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
if self.embedder:
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
config = ChromaDBConfig(
embedding_function=cast(
ChromaEmbeddingFunctionWrapper, embedding_function
)
)
self._client = create_client(config)
return self
def _get_client(self) -> BaseClient:
"""Get the appropriate client - instance-specific or global."""

View File

@@ -22,7 +22,6 @@ from pydantic import (
UUID4,
BaseModel,
Field,
InstanceOf,
PrivateAttr,
field_validator,
model_validator,
@@ -204,7 +203,7 @@ class LiteAgent(FlowTrackable, BaseModel):
role: str = Field(description="Role of the agent")
goal: str = Field(description="Goal of the agent")
backstory: str = Field(description="Backstory of the agent")
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
llm: str | BaseLLM | Any | None = Field(
default=None, description="Language model that will run the agent"
)
tools: list[BaseTool] = Field(

View File

@@ -20,8 +20,7 @@ from typing import (
)
from dotenv import load_dotenv
import httpx
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, model_validator
from typing_extensions import Self
from crewai.events.event_bus import crewai_event_bus
@@ -37,7 +36,12 @@ from crewai.events.types.tool_usage_events import (
ToolUsageFinishedEvent,
ToolUsageStartedEvent,
)
from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context
from crewai.llms.base_llm import (
BaseLLM,
JsonResponseFormat,
get_current_call_id,
llm_call_context,
)
from crewai.llms.constants import (
ANTHROPIC_MODELS,
AZURE_MODELS,
@@ -63,8 +67,6 @@ except ImportError:
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
@@ -342,6 +344,27 @@ class AccumulatedToolArgs(BaseModel):
class LLM(BaseLLM):
completion_cost: float | None = None
timeout: float | int | None = None
top_p: float | None = None
n: int | None = None
max_completion_tokens: int | None = None
max_tokens: int | float | None = None
presence_penalty: float | None = None
frequency_penalty: float | None = None
logit_bias: dict[int, float] | None = None
response_format: JsonResponseFormat | type[BaseModel] | None = None
seed: int | None = None
logprobs: int | None = None
top_logprobs: int | None = None
api_base: str | None = None
api_version: str | None = None
callbacks: list[Any] | None = None
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
stream: bool = False
interceptor: Any = None
thinking: Any = None
context_window_size: int = 0
is_anthropic: bool = False
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
"""Factory method that routes to native SDK or falls back to LiteLLM.
@@ -436,10 +459,7 @@ class LLM(BaseLLM):
logger.error(error_msg)
raise ImportError(error_msg) from None
instance = object.__new__(cls)
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
instance.is_litellm = True
return instance
return object.__new__(cls)
@classmethod
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
@@ -624,89 +644,23 @@ class LLM(BaseLLM):
return None
def __init__(
self,
model: str,
timeout: float | int | None = None,
temperature: float | None = None,
top_p: float | None = None,
n: int | None = None,
stop: str | list[str] | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | float | None = None,
presence_penalty: float | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[int, float] | None = None,
response_format: type[BaseModel] | None = None,
seed: int | None = None,
logprobs: int | None = None,
top_logprobs: int | None = None,
base_url: str | None = None,
api_base: str | None = None,
api_version: str | None = None,
api_key: str | None = None,
callbacks: list[Any] | None = None,
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
prefer_upload: bool = False,
**kwargs: Any,
) -> None:
"""Initialize LLM instance.
@model_validator(mode="before")
@classmethod
def _validate_llm_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
model = data.get("model", "")
data["is_anthropic"] = cls._is_anthropic_model(model)
return data
Note: This __init__ method is only called for fallback instances.
Native provider instances handle their own initialization in their respective classes.
"""
super().__init__(
model=model,
temperature=temperature,
api_key=api_key,
base_url=base_url,
timeout=timeout,
**kwargs,
)
self.model = model
self.timeout = timeout
self.temperature = temperature
self.top_p = top_p
self.n = n
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.base_url = base_url
self.api_base = api_base
self.api_version = api_version
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.reasoning_effort = reasoning_effort
self.prefer_upload = prefer_upload
self.additional_params = {
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
}
self.is_anthropic = self._is_anthropic_model(model)
self.stream = stream
self.interceptor = interceptor
litellm.drop_params = True
# Normalize self.stop to always be a list[str]
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
self.set_callbacks(callbacks or [])
self.set_env_callbacks()
@model_validator(mode="after")
def _init_litellm(self) -> LLM:
self.is_litellm = True
if LITELLM_AVAILABLE:
litellm.drop_params = True
self.set_callbacks(self.callbacks or [])
self.set_env_callbacks()
return self
@staticmethod
def _is_anthropic_model(model: str) -> bool:
@@ -1016,21 +970,25 @@ class LLM(BaseLLM):
)
result = instructor_instance.to_pydantic()
structured_response = result.model_dump_json()
usage_dict = self._usage_to_dict(usage_info)
self._handle_emit_call_events(
response=structured_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_dict,
)
return structured_response
usage_dict = self._usage_to_dict(usage_info)
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_dict,
)
return full_response
@@ -1040,12 +998,14 @@ class LLM(BaseLLM):
return tool_result
# --- 10) Emit completion event and return response
usage_dict = self._usage_to_dict(usage_info)
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_dict,
)
return full_response
@@ -1067,6 +1027,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=self._usage_to_dict(usage_info),
)
return full_response
@@ -1218,6 +1179,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=None,
)
return structured_response
@@ -1248,6 +1210,8 @@ class LLM(BaseLLM):
raise LLMContextLengthExceededError(error_msg) from e
raise
response_usage = self._usage_to_dict(getattr(response, "usage", None))
# --- 2) Handle structured output response (when response_model is provided)
if response_model is not None:
# When using instructor/response_model, litellm returns a Pydantic model instance
@@ -1259,6 +1223,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return structured_response
@@ -1290,6 +1255,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return text_response
@@ -1313,6 +1279,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return text_response
@@ -1362,6 +1329,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=None,
)
return structured_response
@@ -1388,6 +1356,8 @@ class LLM(BaseLLM):
raise LLMContextLengthExceededError(error_msg) from e
raise
response_usage = self._usage_to_dict(getattr(response, "usage", None))
if response_model is not None:
if isinstance(response, BaseModel):
structured_response = response.model_dump_json()
@@ -1397,6 +1367,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return structured_response
@@ -1426,6 +1397,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return text_response
@@ -1448,6 +1420,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=response_usage,
)
return text_response
@@ -1594,12 +1567,14 @@ class LLM(BaseLLM):
if result is not None:
return result
usage_dict = self._usage_to_dict(usage_info)
self._handle_emit_call_events(
response=full_response,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params.get("messages"),
usage=usage_dict,
)
return full_response
@@ -1621,6 +1596,7 @@ class LLM(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("messages"),
usage=self._usage_to_dict(usage_info),
)
return full_response
raise
@@ -2007,6 +1983,19 @@ class LLM(BaseLLM):
)
raise
@staticmethod
def _usage_to_dict(usage: Any) -> dict[str, Any] | None:
if usage is None:
return None
if isinstance(usage, dict):
return usage
if hasattr(usage, "model_dump"):
result: dict[str, Any] = usage.model_dump()
return result
if hasattr(usage, "__dict__"):
return {k: v for k, v in vars(usage).items() if not k.startswith("_")}
return None
def _handle_emit_call_events(
self,
response: Any,
@@ -2014,6 +2003,7 @@ class LLM(BaseLLM):
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[LLMMessage] | None = None,
usage: dict[str, Any] | None = None,
) -> None:
"""Handle the events for the LLM call.
@@ -2023,6 +2013,7 @@ class LLM(BaseLLM):
from_task: Optional task object
from_agent: Optional agent object
messages: Optional messages object
usage: Optional token usage data
"""
crewai_event_bus.emit(
self,
@@ -2034,6 +2025,7 @@ class LLM(BaseLLM):
from_agent=from_agent,
model=self.model,
call_id=get_current_call_id(),
usage=usage,
),
)
@@ -2442,7 +2434,7 @@ class LLM(BaseLLM):
**filtered_params,
)
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM:
"""Create a deep copy of the LLM instance."""
import copy

View File

@@ -14,10 +14,18 @@ from datetime import datetime
import json
import logging
import re
from typing import TYPE_CHECKING, Any, Final
from typing import TYPE_CHECKING, Any, Final, Literal
import uuid
from pydantic import BaseModel
from pydantic import (
AliasChoices,
BaseModel,
ConfigDict,
Field,
PrivateAttr,
model_validator,
)
from typing_extensions import TypedDict
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.llm_events import (
@@ -51,6 +59,12 @@ if TYPE_CHECKING:
from crewai.utilities.types import LLMMessage
class JsonResponseFormat(TypedDict):
"""Response format requesting raw JSON output (e.g. ``{"type": "json_object"}``)."""
type: Literal["json_object"]
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
@@ -82,7 +96,7 @@ def get_current_call_id() -> str:
return call_id
class BaseLLM(ABC):
class BaseLLM(BaseModel, ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
@@ -101,56 +115,100 @@ class BaseLLM(ABC):
additional_params: Additional provider-specific parameters.
"""
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
model: str
temperature: float | None = None
api_key: str | None = None
base_url: str | None = None
provider: str = Field(default="openai")
prefer_upload: bool = False
is_litellm: bool = False
stop: list[str] = Field(
default_factory=list,
validation_alias=AliasChoices("stop", "stop_sequences"),
)
additional_params: dict[str, Any] = Field(default_factory=dict)
def __init__(
self,
model: str,
temperature: float | None = None,
api_key: str | None = None,
base_url: str | None = None,
provider: str | None = None,
prefer_upload: bool = False,
**kwargs: Any,
) -> None:
"""Initialize the BaseLLM with default attributes.
def __setattr__(self, name: str, value: Any) -> None:
if name in ("stop", "stop_sequences"):
if value is None:
value = []
elif isinstance(value, str):
value = [value]
elif not isinstance(value, list):
value = list(value)
name = "stop"
try:
super().__setattr__(name, value)
except ValueError:
if name in self.model_fields:
raise # Re-raise validation errors on declared fields
# Fallback for attributes not declared as fields (e.g. mock patching)
object.__setattr__(self, name, value)
except AttributeError:
object.__setattr__(self, name, value)
Args:
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: Optional list of stop sequences for generation.
prefer_upload: Whether to prefer file upload over inline base64.
**kwargs: Additional provider-specific parameters.
def __delattr__(self, name: str) -> None:
try:
super().__delattr__(name)
except AttributeError:
object.__delattr__(self, name)
@property
def stop_sequences(self) -> list[str]:
"""Alias for ``stop`` — kept for backward compatibility with provider APIs.
Writes are handled by ``__setattr__``, which normalizes and redirects
``stop_sequences`` assignments to the ``stop`` field.
"""
if not model:
raise ValueError("Model name is required and cannot be empty")
return self.stop
self.model = model
self.temperature = temperature
self.api_key = api_key
self.base_url = base_url
self.prefer_upload = prefer_upload
# Store additional parameters for provider-specific use
self.additional_params = kwargs
self._provider = provider or "openai"
stop = kwargs.pop("stop", None)
if stop is None:
self.stop: list[str] = []
elif isinstance(stop, str):
self.stop = [stop]
elif isinstance(stop, list):
self.stop = stop
else:
self.stop = []
self._token_usage = {
_token_usage: dict[str, int] = PrivateAttr(
default_factory=lambda: {
"total_tokens": 0,
"prompt_tokens": 0,
"completion_tokens": 0,
"successful_requests": 0,
"cached_prompt_tokens": 0,
}
)
@model_validator(mode="before")
@classmethod
def _validate_init_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if not data.get("model"):
raise ValueError("Model name is required and cannot be empty")
# Normalize stop: accept str, list, or None; also accept stop_sequences alias
stop_seqs = data.pop("stop_sequences", None)
stop = stop_seqs if stop_seqs is not None else data.get("stop")
if stop is None:
data["stop"] = []
elif isinstance(stop, str):
data["stop"] = [stop]
elif isinstance(stop, list):
data["stop"] = stop
else:
data["stop"] = list(stop)
# Default provider
if not data.get("provider"):
data["provider"] = "openai"
# Collect unknown kwargs into additional_params
known_fields = set(cls.model_fields.keys())
extras = {k: v for k, v in data.items() if k not in known_fields}
for k in extras:
data.pop(k)
existing = data.get("additional_params") or {}
existing.update(extras)
data["additional_params"] = existing
return data
def to_config_dict(self) -> dict[str, Any]:
"""Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.
@@ -174,16 +232,6 @@ class BaseLLM(ABC):
return config
@property
def provider(self) -> str:
"""Get the provider of the LLM."""
return self._provider
@provider.setter
def provider(self, value: str) -> None:
"""Set the provider of the LLM."""
self._provider = value
@abstractmethod
def call(
self,
@@ -412,6 +460,7 @@ class BaseLLM(ABC):
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[LLMMessage] | None = None,
usage: dict[str, Any] | None = None,
) -> None:
"""Emit LLM call completed event."""
from crewai.utilities.serialization import to_serializable
@@ -426,6 +475,7 @@ class BaseLLM(ABC):
from_agent=from_agent,
model=self.model,
call_id=get_current_call_id(),
usage=usage,
),
)

View File

@@ -3,12 +3,13 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
from typing import Any, Final, Literal, TypeGuard, cast
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr, model_validator
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -17,9 +18,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from anthropic import Anthropic, AsyncAnthropic, transform_schema
from anthropic.types import (
@@ -150,60 +148,47 @@ class AnthropicCompletion(BaseLLM):
offering native tool use, streaming support, and proper message formatting.
"""
def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
api_key: str | None = None,
base_url: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
max_tokens: int = 4096, # Required for Anthropic
top_p: float | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None,
response_format: type[BaseModel] | None = None,
tool_search: AnthropicToolSearchConfig | bool | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
model: str = "claude-3-5-sonnet-20241022"
timeout: float | None = None
max_retries: int = 2
max_tokens: int = 4096
top_p: float | None = None
stream: bool = False
client_params: dict[str, Any] | None = None
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
thinking: AnthropicThinkingConfig | None = None
response_format: JsonResponseFormat | type[BaseModel] | None = None
tool_search: AnthropicToolSearchConfig | None = None
is_claude_3: bool = False
supports_tools: bool = True
Args:
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
base_url: Custom base URL for Anthropic API
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-1)
max_tokens: Maximum tokens in response (required for Anthropic)
top_p: Nucleus sampling parameter
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
stream: Enable streaming responses
client_params: Additional parameters for the Anthropic client
interceptor: HTTP interceptor for modifying requests/responses at transport level.
response_format: Pydantic model for structured output. When provided, responses
will be validated against this model schema.
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
"bm25". When enabled, tools are automatically marked with defer_loading=True
and a tool search tool is injected into the tools list.
**kwargs: Additional parameters
"""
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
)
_client: Any = PrivateAttr(default=None)
_async_client: Any = PrivateAttr(default=None)
_previous_thinking_blocks: list[Any] = PrivateAttr(default_factory=list)
# Client params
self.interceptor = interceptor
self.client_params = client_params
self.base_url = base_url
self.timeout = timeout
self.max_retries = max_retries
@model_validator(mode="before")
@classmethod
def _normalize_anthropic_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
# Anthropic uses stop_sequences; normalize from stop kwarg
popped = data.pop("stop_sequences", None)
seqs = popped if popped is not None else (data.get("stop") or [])
if isinstance(seqs, str):
seqs = [seqs]
data["stop"] = seqs
data["is_claude_3"] = "claude-3" in data.get("model", "").lower()
# Normalize tool_search
ts = data.get("tool_search")
if ts is True:
data["tool_search"] = AnthropicToolSearchConfig()
elif ts is not None and not isinstance(ts, AnthropicToolSearchConfig):
data["tool_search"] = None
return data
self.client = Anthropic(**self._get_client_params())
@model_validator(mode="after")
def _init_clients(self) -> AnthropicCompletion:
self._client = Anthropic(**self._get_client_params())
async_client_params = self._get_client_params()
if self.interceptor:
@@ -211,51 +196,8 @@ class AnthropicCompletion(BaseLLM):
async_http_client = httpx.AsyncClient(transport=async_transport)
async_client_params["http_client"] = async_http_client
self.async_client = AsyncAnthropic(**async_client_params)
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = []
self.response_format = response_format
# Tool search config
self.tool_search: AnthropicToolSearchConfig | None
if tool_search is True:
self.tool_search = AnthropicToolSearchConfig()
elif isinstance(tool_search, AnthropicToolSearchConfig):
self.tool_search = tool_search
else:
self.tool_search = None
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True
@property
def stop(self) -> list[str]:
"""Get stop sequences sent to the API."""
return self.stop_sequences
@stop.setter
def stop(self, value: list[str] | str | None) -> None:
"""Set stop sequences.
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
are properly sent to the Anthropic API.
Args:
value: Stop sequences as a list, single string, or None
"""
if value is None:
self.stop_sequences = []
elif isinstance(value, str):
self.stop_sequences = [value]
elif isinstance(value, list):
self.stop_sequences = value
else:
self.stop_sequences = []
self._async_client = AsyncAnthropic(**async_client_params)
return self
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Anthropic-specific fields."""
@@ -751,11 +693,11 @@ class AnthropicCompletion(BaseLLM):
)
elif isinstance(content, list):
formatted_messages.append({"role": "assistant", "content": content})
elif self.thinking and self.previous_thinking_blocks:
elif self.thinking and self._previous_thinking_blocks:
structured_content = cast(
list[dict[str, Any]],
[
*self.previous_thinking_blocks,
*self._previous_thinking_blocks,
{"type": "text", "text": content if content else ""},
],
)
@@ -809,7 +751,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
response_model: JsonResponseFormat | type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming message completion."""
uses_file_api = _contains_file_id_reference(params.get("messages", []))
@@ -843,11 +785,11 @@ class AnthropicCompletion(BaseLLM):
try:
if betas:
params["betas"] = betas
response = self.client.beta.messages.create(
response = self._client.beta.messages.create(
**params, extra_body=extra_body
)
else:
response = self.client.messages.create(**params)
response = self._client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -869,6 +811,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
else:
@@ -884,6 +827,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
@@ -906,6 +850,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return list(tool_uses)
@@ -928,7 +873,7 @@ class AnthropicCompletion(BaseLLM):
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
self._previous_thinking_blocks = thinking_blocks
content = self._apply_stop_words(content)
self._emit_call_completed_event(
@@ -937,6 +882,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
if usage.get("total_tokens", 0) > 0:
@@ -952,7 +898,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
response_model: JsonResponseFormat | type[BaseModel] | None = None,
) -> str | Any:
"""Handle streaming message completion."""
betas: list[str] = []
@@ -991,9 +937,9 @@ class AnthropicCompletion(BaseLLM):
current_tool_calls: dict[int, dict[str, Any]] = {}
stream_context = (
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
if betas
else self.client.messages.stream(**stream_params)
else self._client.messages.stream(**stream_params)
)
with stream_context as stream:
response_id = None
@@ -1072,7 +1018,7 @@ class AnthropicCompletion(BaseLLM):
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
self._previous_thinking_blocks = thinking_blocks
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
@@ -1086,6 +1032,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
for block in final_message.content:
@@ -1100,6 +1047,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
@@ -1129,6 +1077,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return self._invoke_after_llm_call_hooks(
@@ -1269,7 +1218,7 @@ class AnthropicCompletion(BaseLLM):
try:
# Send tool results back to Claude for final response
final_response: Message = self.client.messages.create(**follow_up_params)
final_response: Message = self._client.messages.create(**follow_up_params)
# Track token usage for follow-up call
follow_up_usage = self._extract_anthropic_token_usage(final_response)
@@ -1288,7 +1237,7 @@ class AnthropicCompletion(BaseLLM):
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
self._previous_thinking_blocks = thinking_blocks
final_content = self._apply_stop_words(final_content)
@@ -1299,6 +1248,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=follow_up_params["messages"],
usage=follow_up_usage,
)
# Log combined token usage
@@ -1330,7 +1280,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
response_model: JsonResponseFormat | type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming async message completion."""
uses_file_api = _contains_file_id_reference(params.get("messages", []))
@@ -1364,11 +1314,11 @@ class AnthropicCompletion(BaseLLM):
try:
if betas:
params["betas"] = betas
response = await self.async_client.beta.messages.create(
response = await self._async_client.beta.messages.create(
**params, extra_body=extra_body
)
else:
response = await self.async_client.messages.create(**params)
response = await self._async_client.messages.create(**params)
except Exception as e:
if is_context_length_exceeded(e):
@@ -1390,6 +1340,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
else:
@@ -1405,6 +1356,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
@@ -1425,6 +1377,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return list(tool_uses)
@@ -1448,6 +1401,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
if usage.get("total_tokens", 0) > 0:
@@ -1461,7 +1415,7 @@ class AnthropicCompletion(BaseLLM):
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
response_model: JsonResponseFormat | type[BaseModel] | None = None,
) -> str | Any:
"""Handle async streaming message completion."""
betas: list[str] = []
@@ -1498,11 +1452,11 @@ class AnthropicCompletion(BaseLLM):
current_tool_calls: dict[int, dict[str, Any]] = {}
stream_context = (
self.async_client.beta.messages.stream(
self._async_client.beta.messages.stream(
**stream_params, extra_body=extra_body
)
if betas
else self.async_client.messages.stream(**stream_params)
else self._async_client.messages.stream(**stream_params)
)
async with stream_context as stream:
response_id = None
@@ -1585,6 +1539,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
for block in final_message.content:
@@ -1599,6 +1554,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
@@ -1627,6 +1583,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return full_response
@@ -1664,7 +1621,7 @@ class AnthropicCompletion(BaseLLM):
]
try:
final_response: Message = await self.async_client.messages.create(
final_response: Message = await self._async_client.messages.create(
**follow_up_params
)
@@ -1685,6 +1642,7 @@ class AnthropicCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=follow_up_params["messages"],
usage=follow_up_usage,
)
total_usage = {
@@ -1786,8 +1744,8 @@ class AnthropicCompletion(BaseLLM):
from crewai_files.uploaders.anthropic import AnthropicFileUploader
return AnthropicFileUploader(
client=self.client,
async_client=self.async_client,
client=self._client,
async_client=self._async_client,
)
except ImportError:
return None

View File

@@ -3,11 +3,13 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict
from typing import Any, TypedDict
from urllib.parse import urlparse
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -16,10 +18,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from azure.ai.inference import (
ChatCompletionsClient,
@@ -76,109 +74,84 @@ class AzureCompletion(BaseLLM):
offering native function calling, streaming support, and proper Azure authentication.
"""
def __init__(
self,
model: str,
api_key: str | None = None,
endpoint: str | None = None,
api_version: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
stop: list[str] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
):
"""Initialize Azure AI Inference chat completion client.
endpoint: str | None = None
api_version: str | None = None
timeout: float | None = None
max_retries: int = 2
top_p: float | None = None
frequency_penalty: float | None = None
presence_penalty: float | None = None
max_tokens: int | None = None
stream: bool = False
interceptor: BaseInterceptor[Any, Any] | None = None
response_format: type[BaseModel] | None = None
is_openai_model: bool = False
is_azure_openai_endpoint: bool = False
Args:
model: Azure deployment name or model name
api_key: Azure API key (defaults to AZURE_API_KEY env var)
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
timeout: Request timeout in seconds
max_retries: Maximum number of retries
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
frequency_penalty: Frequency penalty (-2 to 2)
presence_penalty: Presence penalty (-2 to 2)
max_tokens: Maximum tokens in response
stop: Stop sequences
stream: Enable streaming responses
interceptor: HTTP interceptor (not yet supported for Azure).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
Only works with OpenAI models deployed on Azure.
**kwargs: Additional parameters
"""
if interceptor is not None:
_client: Any = PrivateAttr(default=None)
_async_client: Any = PrivateAttr(default=None)
@model_validator(mode="before")
@classmethod
def _normalize_azure_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if data.get("interceptor") is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
super().__init__(
model=model, temperature=temperature, stop=stop or [], **kwargs
)
self.api_key = api_key or os.getenv("AZURE_API_KEY")
self.endpoint = (
endpoint
# Resolve env vars
data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY")
data["endpoint"] = (
data.get("endpoint")
or os.getenv("AZURE_ENDPOINT")
or os.getenv("AZURE_OPENAI_ENDPOINT")
or os.getenv("AZURE_API_BASE")
)
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
self.timeout = timeout
self.max_retries = max_retries
data["api_version"] = (
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
)
if not self.api_key:
if not data["api_key"]:
raise ValueError(
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
)
if not self.endpoint:
if not data["endpoint"]:
raise ValueError(
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
)
# Validate and potentially fix Azure OpenAI endpoint URL
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
model = data.get("model", "")
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
data["endpoint"], model
)
data["is_openai_model"] = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
parsed = urlparse(data["endpoint"])
hostname = parsed.hostname or ""
data["is_azure_openai_endpoint"] = (
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
) and "/openai/deployments/" in data["endpoint"]
return data
# Build client kwargs
client_kwargs = {
@model_validator(mode="after")
def _init_clients(self) -> AzureCompletion:
if not self.api_key:
raise ValueError("Azure API key is required.")
client_kwargs: dict[str, Any] = {
"endpoint": self.endpoint,
"credential": AzureKeyCredential(self.api_key),
}
# Add api_version if specified (primarily for Azure OpenAI endpoints)
if self.api_version:
client_kwargs["api_version"] = self.api_version
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.stream = stream
self.response_format = response_format
self.is_openai_model = any(
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
)
self.is_azure_openai_endpoint = (
"openai.azure.com" in self.endpoint
and "/openai/deployments/" in self.endpoint
)
self._client = ChatCompletionsClient(**client_kwargs)
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
return self
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Azure-specific fields."""
@@ -215,7 +188,11 @@ class AzureCompletion(BaseLLM):
Returns:
Validated and potentially corrected endpoint URL
"""
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
ep_host = urlparse(endpoint).hostname or ""
is_azure_openai = ep_host == "openai.azure.com" or ep_host.endswith(
".openai.azure.com"
)
if is_azure_openai and "/openai/deployments/" not in endpoint:
endpoint = endpoint.rstrip("/")
if not endpoint.endswith("/openai/deployments"):
@@ -592,6 +569,7 @@ class AzureCompletion(BaseLLM):
params: AzureCompletionParams,
from_task: Any | None = None,
from_agent: Any | None = None,
usage: dict[str, Any] | None = None,
) -> BaseModel:
"""Validate content against response model and emit completion event.
@@ -617,6 +595,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_data
@@ -666,6 +645,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return list(message.tool_calls)
@@ -703,6 +683,7 @@ class AzureCompletion(BaseLLM):
params=params,
from_task=from_task,
from_agent=from_agent,
usage=usage,
)
content = self._apply_stop_words(content)
@@ -714,6 +695,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return self._invoke_after_llm_call_hooks(
@@ -731,7 +713,7 @@ class AzureCompletion(BaseLLM):
"""Handle non-streaming chat completion."""
try:
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
response: ChatCompletions = self._client.complete(**params)
return self._process_completion_response(
response=response,
params=params,
@@ -817,7 +799,7 @@ class AzureCompletion(BaseLLM):
self,
full_response: str,
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, Any] | None,
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -829,7 +811,7 @@ class AzureCompletion(BaseLLM):
Args:
full_response: The complete streamed response content
tool_calls: Dictionary of tool calls accumulated during streaming
usage_data: Token usage data from the stream
usage_data: Token usage data from the stream, or None if unavailable
params: Completion parameters containing messages
available_functions: Available functions for tool calling
from_task: Task that initiated the call
@@ -839,7 +821,8 @@ class AzureCompletion(BaseLLM):
Returns:
Final response content after processing, or structured output
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
# Handle structured output validation
if response_model and self.is_openai_model:
@@ -849,6 +832,7 @@ class AzureCompletion(BaseLLM):
params=params,
from_task=from_task,
from_agent=from_agent,
usage=usage_data,
)
# If there are tool_calls but no available_functions, return them
@@ -871,6 +855,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return formatted_tool_calls
@@ -907,6 +892,7 @@ class AzureCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return self._invoke_after_llm_call_hooks(
@@ -925,8 +911,8 @@ class AzureCompletion(BaseLLM):
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
for update in self.client.complete(**params): # type: ignore[arg-type]
usage_data: dict[str, Any] | None = None
for update in self._client.complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.usage:
usage = update.usage
@@ -967,7 +953,7 @@ class AzureCompletion(BaseLLM):
"""Handle non-streaming chat completion asynchronously."""
try:
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
response: ChatCompletions = await self._async_client.complete(**params)
return self._process_completion_response(
response=response,
params=params,
@@ -991,10 +977,10 @@ class AzureCompletion(BaseLLM):
full_response = ""
tool_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
async for update in stream: # type: ignore[union-attr]
stream = await self._async_client.complete(**params)
async for update in stream:
if isinstance(update, StreamingChatCompletionsUpdate):
if hasattr(update, "usage") and update.usage:
usage = update.usage
@@ -1110,8 +1096,8 @@ class AzureCompletion(BaseLLM):
This ensures proper cleanup of the underlying aiohttp session
to avoid unclosed connector warnings.
"""
if hasattr(self.async_client, "close"):
await self.async_client.close()
if hasattr(self._async_client, "close"):
await self._async_client.close()
async def __aenter__(self) -> Self:
"""Async context manager entry."""

View File

@@ -7,7 +7,7 @@ import logging
import os
from typing import TYPE_CHECKING, Any, TypedDict, cast
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr, model_validator
from typing_extensions import Required
from crewai.events.types.llm_events import LLMCallType
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
ToolTypeDef,
)
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.base import BaseInterceptor
try:
@@ -228,129 +228,97 @@ class BedrockCompletion(BaseLLM):
- Model-specific conversation format handling (e.g., Cohere requirements)
"""
def __init__(
self,
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
aws_access_key_id: str | None = None,
aws_secret_access_key: str | None = None,
aws_session_token: str | None = None,
region_name: str | None = None,
temperature: float | None = None,
max_tokens: int | None = None,
top_p: float | None = None,
top_k: int | None = None,
stop_sequences: Sequence[str] | None = None,
stream: bool = False,
guardrail_config: dict[str, Any] | None = None,
additional_model_request_fields: dict[str, Any] | None = None,
additional_model_response_field_paths: list[str] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
response_format: type[BaseModel] | None = None,
**kwargs: Any,
) -> None:
"""Initialize AWS Bedrock completion client.
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
aws_access_key_id: str | None = None
aws_secret_access_key: str | None = None
aws_session_token: str | None = None
region_name: str | None = None
max_tokens: int | None = None
top_p: float | None = None
top_k: int | None = None
stream: bool = False
guardrail_config: dict[str, Any] | None = None
additional_model_request_fields: dict[str, Any] | None = None
additional_model_response_field_paths: list[str] | None = None
interceptor: BaseInterceptor[Any, Any] | None = None
response_format: type[BaseModel] | None = None
is_claude_model: bool = False
supports_tools: bool = True
supports_streaming: bool = True
model_id: str = ""
Args:
model: The Bedrock model ID to use
aws_access_key_id: AWS access key (defaults to environment variable)
aws_secret_access_key: AWS secret key (defaults to environment variable)
aws_session_token: AWS session token for temporary credentials
region_name: AWS region name
temperature: Sampling temperature for response generation
max_tokens: Maximum tokens to generate
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter (Claude models only)
stop_sequences: List of sequences that stop generation
stream: Whether to use streaming responses
guardrail_config: Guardrail configuration for content filtering
additional_model_request_fields: Model-specific request parameters
additional_model_response_field_paths: Custom response field paths
interceptor: HTTP interceptor (not yet supported for Bedrock).
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
**kwargs: Additional parameters
"""
if interceptor is not None:
_client: Any = PrivateAttr(default=None)
_async_exit_stack: Any = PrivateAttr(default=None)
_async_client_initialized: bool = PrivateAttr(default=False)
_async_client: Any = PrivateAttr(default=None)
@model_validator(mode="before")
@classmethod
def _normalize_bedrock_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if data.get("interceptor") is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
# Extract provider from kwargs to avoid duplicate argument
kwargs.pop("provider", None)
# Force provider to bedrock
data.pop("provider", None)
data["provider"] = "bedrock"
super().__init__(
model=model,
temperature=temperature,
stop=stop_sequences or [],
provider="bedrock",
**kwargs,
# Normalize stop_sequences from stop kwarg
popped = data.pop("stop_sequences", None)
seqs = popped if popped is not None else (data.get("stop") or [])
if isinstance(seqs, str):
seqs = [seqs]
elif isinstance(seqs, Sequence) and not isinstance(seqs, list):
seqs = list(seqs)
data["stop"] = seqs
# Resolve env vars
data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv(
"AWS_ACCESS_KEY_ID"
)
# Configure client with timeouts and retries following AWS best practices
config = Config(
read_timeout=300,
retries={
"max_attempts": 3,
"mode": "adaptive",
},
tcp_keepalive=True,
data["aws_secret_access_key"] = data.get("aws_secret_access_key") or os.getenv(
"AWS_SECRET_ACCESS_KEY"
)
self.region_name = (
region_name
data["aws_session_token"] = data.get("aws_session_token") or os.getenv(
"AWS_SESSION_TOKEN"
)
data["region_name"] = (
data.get("region_name")
or os.getenv("AWS_DEFAULT_REGION")
or os.getenv("AWS_REGION_NAME")
or "us-east-1"
)
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
"AWS_SECRET_ACCESS_KEY"
)
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
model = data.get("model", "anthropic.claude-3-5-sonnet-20241022-v2:0")
data["is_claude_model"] = "claude" in model.lower()
data["model_id"] = model
return data
# Initialize Bedrock client with proper configuration
@model_validator(mode="after")
def _init_clients(self) -> BedrockCompletion:
config = Config(
read_timeout=300,
retries={"max_attempts": 3, "mode": "adaptive"},
tcp_keepalive=True,
)
session = Session(
aws_access_key_id=self.aws_access_key_id,
aws_secret_access_key=self.aws_secret_access_key,
aws_session_token=self.aws_session_token,
region_name=self.region_name,
)
self.client = session.client("bedrock-runtime", config=config)
self._client = session.client("bedrock-runtime", config=config)
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
self._async_client_initialized = False
# Store completion parameters
self.max_tokens = max_tokens
self.top_p = top_p
self.top_k = top_k
self.stream = stream
self.stop_sequences = stop_sequences
self.response_format = response_format
# Store advanced features (optional)
self.guardrail_config = guardrail_config
self.additional_model_request_fields = additional_model_request_fields
self.additional_model_response_field_paths = (
additional_model_response_field_paths
)
# Model-specific settings
self.is_claude_model = "claude" in model.lower()
self.supports_tools = True # Converse API supports tools for most models
self.supports_streaming = True
# Handle inference profiles for newer models
self.model_id = model
return self
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Bedrock-specific fields."""
config = super().to_config_dict()
# NOTE: AWS credentials (access_key, secret_key, session_token) are
# intentionally excluded — they must come from env on resume.
if self.region_name and self.region_name != "us-east-1":
config["region_name"] = self.region_name
if self.max_tokens is not None:
@@ -363,30 +331,6 @@ class BedrockCompletion(BaseLLM):
config["guardrail_config"] = self.guardrail_config
return config
@property
def stop(self) -> list[str]:
"""Get stop sequences sent to the API."""
return [] if self.stop_sequences is None else list(self.stop_sequences)
@stop.setter
def stop(self, value: Sequence[str] | str | None) -> None:
"""Set stop sequences.
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
are properly sent to the Bedrock API.
Args:
value: Stop sequences as a Sequence, single string, or None
"""
if value is None:
self.stop_sequences = []
elif isinstance(value, str):
self.stop_sequences = [value]
elif isinstance(value, Sequence):
self.stop_sequences = list(value)
else:
self.stop_sequences = []
def call(
self,
messages: str | list[LLMMessage],
@@ -710,7 +654,7 @@ class BedrockCompletion(BaseLLM):
raise ValueError(f"Invalid message format at index {i}")
# Call Bedrock Converse API with proper error handling
response = self.client.converse(
response = self._client.converse(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
@@ -720,8 +664,9 @@ class BedrockCompletion(BaseLLM):
)
# Track token usage according to AWS response format
if "usage" in response:
self._track_token_usage_internal(response["usage"])
usage = response.get("usage")
if usage:
self._track_token_usage_internal(usage)
stop_reason = response.get("stopReason")
if stop_reason:
@@ -761,6 +706,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return result
except Exception as e:
@@ -783,6 +729,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return non_structured_output_tool_uses
@@ -862,6 +809,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return self._invoke_after_llm_call_hooks(
@@ -992,15 +940,16 @@ class BedrockCompletion(BaseLLM):
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
usage_data: dict[str, Any] | None = None
try:
response = self.client.converse_stream(
response = self._client.converse_stream(
modelId=self.model_id,
messages=cast(
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
cast(object, messages),
),
**body, # type: ignore[arg-type]
**body,
)
stream = response.get("stream")
@@ -1101,6 +1050,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage_data,
)
return result # type: ignore[return-value]
except Exception as e:
@@ -1168,6 +1118,7 @@ class BedrockCompletion(BaseLLM):
metadata = event["metadata"]
if "usage" in metadata:
usage_metrics = metadata["usage"]
usage_data = usage_metrics
self._track_token_usage_internal(usage_metrics)
logging.debug(f"Token usage: {usage_metrics}")
if "trace" in metadata:
@@ -1197,6 +1148,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage_data,
)
return full_response
@@ -1308,8 +1260,9 @@ class BedrockCompletion(BaseLLM):
**body,
)
if "usage" in response:
self._track_token_usage_internal(response["usage"])
usage = response.get("usage")
if usage:
self._track_token_usage_internal(usage)
stop_reason = response.get("stopReason")
if stop_reason:
@@ -1348,6 +1301,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return result
except Exception as e:
@@ -1370,6 +1324,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return non_structured_output_tool_uses
@@ -1444,6 +1399,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage,
)
return text_content
@@ -1564,6 +1520,7 @@ class BedrockCompletion(BaseLLM):
tool_use_id: str | None = None
tool_use_index = 0
accumulated_tool_input = ""
usage_data: dict[str, Any] | None = None
try:
async_client = await self._ensure_async_client()
@@ -1675,6 +1632,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage_data,
)
return result # type: ignore[return-value]
except Exception as e:
@@ -1747,6 +1705,7 @@ class BedrockCompletion(BaseLLM):
metadata = event["metadata"]
if "usage" in metadata:
usage_metrics = metadata["usage"]
usage_data = usage_metrics
self._track_token_usage_internal(usage_metrics)
logging.debug(f"Token usage: {usage_metrics}")
if "trace" in metadata:
@@ -1776,6 +1735,7 @@ class BedrockCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages,
usage=usage_data,
)
return self._invoke_after_llm_call_hooks(

View File

@@ -5,12 +5,13 @@ import json
import logging
import os
import re
from typing import TYPE_CHECKING, Any, Literal, cast
from typing import Any, Literal, cast
from pydantic import BaseModel
from pydantic import BaseModel, Field, PrivateAttr, model_validator
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.llms.hooks.base import BaseInterceptor
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededError,
@@ -19,10 +20,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
try:
from google import genai
from google.genai import types
@@ -44,137 +41,84 @@ class GeminiCompletion(BaseLLM):
offering native function calling, streaming support, and proper Gemini formatting.
"""
def __init__(
self,
model: str = "gemini-2.0-flash-001",
api_key: str | None = None,
project: str | None = None,
location: str | None = None,
temperature: float | None = None,
top_p: float | None = None,
top_k: int | None = None,
max_output_tokens: int | None = None,
stop_sequences: list[str] | None = None,
stream: bool = False,
safety_settings: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[Any, Any] | None = None,
use_vertexai: bool | None = None,
response_format: type[BaseModel] | None = None,
thinking_config: types.ThinkingConfig | None = None,
**kwargs: Any,
):
"""Initialize Google Gemini chat completion client.
model: str = "gemini-2.0-flash-001"
project: str | None = None
location: str | None = None
top_p: float | None = None
top_k: int | None = None
max_output_tokens: int | None = None
stream: bool = False
safety_settings: dict[str, Any] = Field(default_factory=dict)
client_params: dict[str, Any] = Field(default_factory=dict)
interceptor: BaseInterceptor[Any, Any] | None = None
use_vertexai: bool = False
response_format: type[BaseModel] | None = None
thinking_config: Any = None
tools: list[dict[str, Any]] | None = None
supports_tools: bool = False
is_gemini_2_0: bool = False
Args:
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
api_key: Google API key for Gemini API authentication.
Defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var.
NOTE: Cannot be used with Vertex AI (project parameter). Use Gemini API instead.
project: Google Cloud project ID for Vertex AI with ADC authentication.
Requires Application Default Credentials (gcloud auth application-default login).
NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC.
If both api_key and project are set, api_key takes precedence.
location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1')
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
top_k: Top-k sampling parameter
max_output_tokens: Maximum tokens in response
stop_sequences: Stop sequences
stream: Enable streaming responses
safety_settings: Safety filter settings
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
Supports parameters like http_options, credentials, debug_config, etc.
interceptor: HTTP interceptor (not yet supported for Gemini).
use_vertexai: Whether to use Vertex AI instead of Gemini API.
- True: Use Vertex AI (with ADC or Express mode with API key)
- False: Use Gemini API (explicitly override env var)
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
When using Vertex AI with API key (Express mode), http_options with
api_version="v1" is automatically configured.
response_format: Pydantic model for structured output. Used as default when
response_model is not passed to call()/acall() methods.
thinking_config: ThinkingConfig for thinking models (gemini-2.5+, gemini-3+).
Controls thought output via include_thoughts, thinking_budget,
and thinking_level. When None, thinking models automatically
get include_thoughts=True so thought content is surfaced.
**kwargs: Additional parameters
"""
if interceptor is not None:
_client: Any = PrivateAttr(default=None)
@model_validator(mode="before")
@classmethod
def _normalize_gemini_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if data.get("interceptor") is not None:
raise NotImplementedError(
"HTTP interceptors are not yet supported for Google Gemini provider. "
"Interceptors are currently supported for OpenAI and Anthropic providers only."
)
super().__init__(
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
# Normalize stop_sequences from stop kwarg
popped = data.pop("stop_sequences", None)
seqs = popped if popped is not None else (data.get("stop") or [])
if isinstance(seqs, str):
seqs = [seqs]
data["stop"] = seqs
# Resolve env vars
data["api_key"] = (
data.get("api_key")
or os.getenv("GOOGLE_API_KEY")
or os.getenv("GEMINI_API_KEY")
)
data["project"] = data.get("project") or os.getenv("GOOGLE_CLOUD_PROJECT")
data["location"] = (
data.get("location") or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
)
# Store client params for later use
self.client_params = client_params or {}
# Get API configuration with environment variable fallbacks
self.api_key = (
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
)
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
if use_vertexai is None:
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
self.client = self._initialize_client(use_vertexai)
# Store completion parameters
self.top_p = top_p
self.top_k = top_k
self.max_output_tokens = max_output_tokens
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None
self.response_format = response_format
use_vx = data.get("use_vertexai")
if use_vx is None:
use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
data["use_vertexai"] = use_vx
# Model-specific settings
model = data.get("model", "gemini-2.0-flash-001")
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
self.supports_tools = bool(
data["supports_tools"] = bool(
version_match and float(version_match.group(1)) >= 1.5
)
self.is_gemini_2_0 = bool(
data["is_gemini_2_0"] = bool(
version_match and float(version_match.group(1)) >= 2.0
)
self.thinking_config = thinking_config
# Auto-enable thinking for gemini-2.5+
if (
self.thinking_config is None
data.get("thinking_config") is None
and version_match
and float(version_match.group(1)) >= 2.5
):
self.thinking_config = types.ThinkingConfig(include_thoughts=True)
data["thinking_config"] = types.ThinkingConfig(include_thoughts=True)
@property
def stop(self) -> list[str]:
"""Get stop sequences sent to the API."""
return self.stop_sequences
return data
@stop.setter
def stop(self, value: list[str] | str | None) -> None:
"""Set stop sequences.
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
are properly sent to the Gemini API.
Args:
value: Stop sequences as a list, single string, or None
"""
if value is None:
self.stop_sequences = []
elif isinstance(value, str):
self.stop_sequences = [value]
elif isinstance(value, list):
self.stop_sequences = value
else:
self.stop_sequences = []
@model_validator(mode="after")
def _init_client(self) -> GeminiCompletion:
self._client = self._initialize_client(self.use_vertexai)
return self
def to_config_dict(self) -> dict[str, Any]:
"""Extend base config with Gemini/Vertex-specific fields."""
@@ -283,8 +227,8 @@ class GeminiCompletion(BaseLLM):
if (
hasattr(self, "client")
and hasattr(self.client, "vertexai")
and self.client.vertexai
and hasattr(self._client, "vertexai")
and self._client.vertexai
):
# Vertex AI configuration
params.update(
@@ -721,6 +665,7 @@ class GeminiCompletion(BaseLLM):
messages_for_event: list[LLMMessage],
from_task: Any | None = None,
from_agent: Any | None = None,
usage: dict[str, Any] | None = None,
) -> BaseModel:
"""Validate content against response model and emit completion event.
@@ -746,6 +691,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
usage=usage,
)
return structured_data
@@ -761,6 +707,7 @@ class GeminiCompletion(BaseLLM):
response_model: type[BaseModel] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
usage: dict[str, Any] | None = None,
) -> str | BaseModel:
"""Finalize completion response with validation and event emission.
@@ -784,6 +731,7 @@ class GeminiCompletion(BaseLLM):
messages_for_event=messages_for_event,
from_task=from_task,
from_agent=from_agent,
usage=usage,
)
self._emit_call_completed_event(
@@ -792,6 +740,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
usage=usage,
)
return self._invoke_after_llm_call_hooks(
@@ -805,6 +754,7 @@ class GeminiCompletion(BaseLLM):
contents: list[types.Content],
from_task: Any | None = None,
from_agent: Any | None = None,
usage: dict[str, Any] | None = None,
) -> BaseModel:
"""Validate and emit event for structured_output tool call.
@@ -829,6 +779,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=self._convert_contents_to_dict(contents),
usage=usage,
)
return validated_data
except Exception as e:
@@ -847,6 +798,7 @@ class GeminiCompletion(BaseLLM):
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
usage: dict[str, Any] | None = None,
) -> str | Any:
"""Process response, execute function calls, and finalize completion.
@@ -887,6 +839,7 @@ class GeminiCompletion(BaseLLM):
contents=contents,
from_task=from_task,
from_agent=from_agent,
usage=usage,
)
# Filter out structured_output from function calls returned to executor
@@ -908,6 +861,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=self._convert_contents_to_dict(contents),
usage=usage,
)
return non_structured_output_parts
@@ -949,6 +903,7 @@ class GeminiCompletion(BaseLLM):
response_model=effective_response_model,
from_task=from_task,
from_agent=from_agent,
usage=usage,
)
def _process_stream_chunk(
@@ -956,10 +911,10 @@ class GeminiCompletion(BaseLLM):
chunk: GenerateContentResponse,
full_response: str,
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, int] | None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int]]:
) -> tuple[str, dict[int, dict[str, Any]], dict[str, int] | None]:
"""Process a single streaming chunk.
Args:
@@ -1035,7 +990,7 @@ class GeminiCompletion(BaseLLM):
self,
full_response: str,
function_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, int] | None,
contents: list[types.Content],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1047,7 +1002,7 @@ class GeminiCompletion(BaseLLM):
Args:
full_response: The complete streamed response content
function_calls: Dictionary of function calls accumulated during streaming
usage_data: Token usage data from the stream
usage_data: Token usage data from the stream, or None if unavailable
contents: Original contents for event conversion
available_functions: Available functions for function calling
from_task: Task that initiated the call
@@ -1057,7 +1012,8 @@ class GeminiCompletion(BaseLLM):
Returns:
Final response content after processing
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
if response_model and function_calls:
for call_data in function_calls.values():
@@ -1069,6 +1025,7 @@ class GeminiCompletion(BaseLLM):
contents=contents,
from_task=from_task,
from_agent=from_agent,
usage=usage_data,
)
non_structured_output_calls = {
@@ -1097,6 +1054,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=self._convert_contents_to_dict(contents),
usage=usage_data,
)
return formatted_function_calls
@@ -1137,6 +1095,7 @@ class GeminiCompletion(BaseLLM):
response_model=effective_response_model,
from_task=from_task,
from_agent=from_agent,
usage=usage_data,
)
def _handle_completion(
@@ -1152,7 +1111,7 @@ class GeminiCompletion(BaseLLM):
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self.client.models.generate_content(
response = self._client.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1174,6 +1133,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
usage=usage,
)
def _handle_streaming_completion(
@@ -1188,11 +1148,11 @@ class GeminiCompletion(BaseLLM):
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, int] | None = None
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self.client.models.generate_content_stream(
for chunk in self._client.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1230,7 +1190,7 @@ class GeminiCompletion(BaseLLM):
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = await self.client.aio.models.generate_content(
response = await self._client.aio.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1252,6 +1212,7 @@ class GeminiCompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
usage=usage,
)
async def _ahandle_streaming_completion(
@@ -1266,11 +1227,11 @@ class GeminiCompletion(BaseLLM):
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[int, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
usage_data: dict[str, int] | None = None
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
stream = await self.client.aio.models.generate_content_stream(
stream = await self._client.aio.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
@@ -1474,6 +1435,6 @@ class GeminiCompletion(BaseLLM):
try:
from crewai_files.uploaders.gemini import GeminiFileUploader
return GeminiFileUploader(client=self.client)
return GeminiFileUploader(client=self._client)
except ImportError:
return None

View File

@@ -14,10 +14,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
from openai.types.responses import Response
from pydantic import BaseModel
from pydantic import BaseModel, PrivateAttr, model_validator
from crewai.events.types.llm_events import LLMCallType
from crewai.llms.base_llm import BaseLLM, llm_call_context
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
from crewai.utilities.agent_utils import is_context_length_exceeded
from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -29,7 +30,6 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
@@ -183,77 +183,69 @@ class OpenAICompletion(BaseLLM):
"computer_use": "computer_use_preview",
}
def __init__(
self,
model: str = "gpt-4o",
api_key: str | None = None,
base_url: str | None = None,
organization: str | None = None,
project: str | None = None,
timeout: float | None = None,
max_retries: int = 2,
default_headers: dict[str, str] | None = None,
default_query: dict[str, Any] | None = None,
client_params: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
frequency_penalty: float | None = None,
presence_penalty: float | None = None,
max_tokens: int | None = None,
max_completion_tokens: int | None = None,
seed: int | None = None,
stream: bool = False,
response_format: dict[str, Any] | type[BaseModel] | None = None,
logprobs: bool | None = None,
top_logprobs: int | None = None,
reasoning_effort: str | None = None,
provider: str | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
api: Literal["completions", "responses"] = "completions",
instructions: str | None = None,
store: bool | None = None,
previous_response_id: str | None = None,
include: list[str] | None = None,
builtin_tools: list[str] | None = None,
parse_tool_outputs: bool = False,
auto_chain: bool = False,
auto_chain_reasoning: bool = False,
**kwargs: Any,
) -> None:
"""Initialize OpenAI completion client."""
model: str = "gpt-4o"
organization: str | None = None
project: str | None = None
timeout: float | None = None
max_retries: int = 2
default_headers: dict[str, str] | None = None
default_query: dict[str, Any] | None = None
client_params: dict[str, Any] | None = None
top_p: float | None = None
frequency_penalty: float | None = None
presence_penalty: float | None = None
max_tokens: int | None = None
max_completion_tokens: int | None = None
seed: int | None = None
stream: bool = False
response_format: JsonResponseFormat | type[BaseModel] | None = None
logprobs: bool | None = None
top_logprobs: int | None = None
reasoning_effort: str | None = None
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
api: Literal["completions", "responses"] = "completions"
instructions: str | None = None
store: bool | None = None
previous_response_id: str | None = None
include: list[str] | None = None
builtin_tools: list[str] | None = None
parse_tool_outputs: bool = False
auto_chain: bool = False
auto_chain_reasoning: bool = False
api_base: str | None = None
is_o1_model: bool = False
is_gpt4_model: bool = False
if provider is None:
provider = kwargs.pop("provider", "openai")
_client: Any = PrivateAttr(default=None)
_async_client: Any = PrivateAttr(default=None)
_last_response_id: str | None = PrivateAttr(default=None)
_last_reasoning_items: list[Any] | None = PrivateAttr(default=None)
self.interceptor = interceptor
# Client configuration attributes
self.organization = organization
self.project = project
self.max_retries = max_retries
self.default_headers = default_headers
self.default_query = default_query
self.client_params = client_params
self.timeout = timeout
self.base_url = base_url
self.api_base = kwargs.pop("api_base", None)
super().__init__(
model=model,
temperature=temperature,
api_key=api_key or os.getenv("OPENAI_API_KEY"),
base_url=base_url,
timeout=timeout,
provider=provider,
**kwargs,
)
@model_validator(mode="before")
@classmethod
def _normalize_openai_fields(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
if not data.get("provider"):
data["provider"] = "openai"
data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY")
# Extract api_base from kwargs if present
if "api_base" not in data:
data["api_base"] = None
model = data.get("model", "gpt-4o")
data["is_o1_model"] = "o1" in model.lower()
data["is_gpt4_model"] = "gpt-4" in model.lower()
return data
@model_validator(mode="after")
def _init_clients(self) -> OpenAICompletion:
client_config = self._get_client_params()
if self.interceptor:
transport = HTTPTransport(interceptor=self.interceptor)
http_client = httpx.Client(transport=transport)
client_config["http_client"] = http_client
self.client = OpenAI(**client_config)
self._client = OpenAI(**client_config)
async_client_config = self._get_client_params()
if self.interceptor:
@@ -261,35 +253,8 @@ class OpenAICompletion(BaseLLM):
async_http_client = httpx.AsyncClient(transport=async_transport)
async_client_config["http_client"] = async_http_client
self.async_client = AsyncOpenAI(**async_client_config)
# Completion parameters
self.top_p = top_p
self.frequency_penalty = frequency_penalty
self.presence_penalty = presence_penalty
self.max_tokens = max_tokens
self.max_completion_tokens = max_completion_tokens
self.seed = seed
self.stream = stream
self.response_format = response_format
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.reasoning_effort = reasoning_effort
self.is_o1_model = "o1" in model.lower()
self.is_gpt4_model = "gpt-4" in model.lower()
# API selection and Responses API parameters
self.api = api
self.instructions = instructions
self.store = store
self.previous_response_id = previous_response_id
self.include = include
self.builtin_tools = builtin_tools
self.parse_tool_outputs = parse_tool_outputs
self.auto_chain = auto_chain
self.auto_chain_reasoning = auto_chain_reasoning
self._last_response_id: str | None = None
self._last_reasoning_items: list[Any] | None = None
self._async_client = AsyncOpenAI(**async_client_config)
return self
@property
def last_response_id(self) -> str | None:
@@ -818,7 +783,7 @@ class OpenAICompletion(BaseLLM):
) -> str | ResponsesAPIResult | Any:
"""Handle non-streaming Responses API call."""
try:
response: Response = self.client.responses.create(**params)
response: Response = self._client.responses.create(**params)
# Track response ID for auto-chaining
if self.auto_chain and response.id:
@@ -844,6 +809,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return parsed_result
@@ -856,6 +822,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return function_calls
@@ -893,6 +860,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return structured_result
except ValueError as e:
@@ -906,6 +874,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
content = self._invoke_after_llm_call_hooks(
@@ -950,7 +919,7 @@ class OpenAICompletion(BaseLLM):
) -> str | ResponsesAPIResult | Any:
"""Handle async non-streaming Responses API call."""
try:
response: Response = await self.async_client.responses.create(**params)
response: Response = await self._async_client.responses.create(**params)
# Track response ID for auto-chaining
if self.auto_chain and response.id:
@@ -976,6 +945,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return parsed_result
@@ -988,6 +958,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return function_calls
@@ -1025,6 +996,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return structured_result
except ValueError as e:
@@ -1038,6 +1010,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
except NotFoundError as e:
@@ -1080,8 +1053,9 @@ class OpenAICompletion(BaseLLM):
full_response = ""
function_calls: list[dict[str, Any]] = []
final_response: Response | None = None
usage: dict[str, Any] | None = None
stream = self.client.responses.create(**params)
stream = self._client.responses.create(**params)
response_id_stream = None
for event in stream:
@@ -1137,6 +1111,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return parsed_result
@@ -1173,6 +1148,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return structured_result
except ValueError as e:
@@ -1186,6 +1162,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return self._invoke_after_llm_call_hooks(
@@ -1204,8 +1181,9 @@ class OpenAICompletion(BaseLLM):
full_response = ""
function_calls: list[dict[str, Any]] = []
final_response: Response | None = None
usage: dict[str, Any] | None = None
stream = await self.async_client.responses.create(**params)
stream = await self._async_client.responses.create(**params)
response_id_stream = None
async for event in stream:
@@ -1261,6 +1239,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return parsed_result
@@ -1297,6 +1276,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return structured_result
except ValueError as e:
@@ -1310,6 +1290,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params.get("input", []),
usage=usage,
)
return full_response
@@ -1595,7 +1576,7 @@ class OpenAICompletion(BaseLLM):
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = self.client.beta.chat.completions.parse(
parsed_response = self._client.beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
@@ -1615,10 +1596,11 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return parsed_object
response: ChatCompletion = self.client.chat.completions.create(**params)
response: ChatCompletion = self._client.chat.completions.create(**params)
usage = self._extract_openai_token_usage(response)
@@ -1636,6 +1618,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return list(message.tool_calls)
@@ -1674,6 +1657,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_result
except ValueError as e:
@@ -1687,6 +1671,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
if usage.get("total_tokens", 0) > 0:
@@ -1728,7 +1713,7 @@ class OpenAICompletion(BaseLLM):
self,
full_response: str,
tool_calls: dict[int, dict[str, Any]],
usage_data: dict[str, int],
usage_data: dict[str, Any] | None,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1739,7 +1724,7 @@ class OpenAICompletion(BaseLLM):
Args:
full_response: The accumulated text response from the stream.
tool_calls: Accumulated tool calls from the stream, keyed by index.
usage_data: Token usage data from the stream.
usage_data: Token usage data from the stream, or None if unavailable.
params: The completion parameters containing messages.
available_functions: Available functions for tool calling.
from_task: Task that initiated the call.
@@ -1750,7 +1735,8 @@ class OpenAICompletion(BaseLLM):
tool execution result when available_functions is provided,
or the text response string.
"""
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
if tool_calls and not available_functions:
tool_calls_list = [
@@ -1771,6 +1757,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return tool_calls_list
@@ -1813,6 +1800,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return full_response
@@ -1837,7 +1825,7 @@ class OpenAICompletion(BaseLLM):
}
stream: ChatCompletionStream[BaseModel]
with self.client.beta.chat.completions.stream(
with self._client.beta.chat.completions.stream(
**parse_params, response_format=response_model
) as stream:
for chunk in stream:
@@ -1866,6 +1854,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return parsed_result
@@ -1873,10 +1862,10 @@ class OpenAICompletion(BaseLLM):
return ""
completion_stream: Stream[ChatCompletionChunk] = (
self.client.chat.completions.create(**params)
self._client.chat.completions.create(**params)
)
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
for completion_chunk in completion_stream:
response_id_stream = (
@@ -1970,7 +1959,7 @@ class OpenAICompletion(BaseLLM):
parse_params = {
k: v for k, v in params.items() if k != "response_format"
}
parsed_response = await self.async_client.beta.chat.completions.parse(
parsed_response = await self._async_client.beta.chat.completions.parse(
**parse_params,
response_format=response_model,
)
@@ -1990,10 +1979,11 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return parsed_object
response: ChatCompletion = await self.async_client.chat.completions.create(
response: ChatCompletion = await self._async_client.chat.completions.create(
**params
)
@@ -2013,6 +2003,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return list(message.tool_calls)
@@ -2051,6 +2042,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
return structured_result
except ValueError as e:
@@ -2064,6 +2056,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage,
)
if usage.get("total_tokens", 0) > 0:
@@ -2111,10 +2104,10 @@ class OpenAICompletion(BaseLLM):
if response_model:
completion_stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
] = await self._async_client.chat.completions.create(**params)
accumulated_content = ""
usage_data = {"total_tokens": 0}
usage_data: dict[str, Any] | None = None
async for chunk in completion_stream:
response_id_stream = chunk.id if hasattr(chunk, "id") else None
@@ -2137,7 +2130,8 @@ class OpenAICompletion(BaseLLM):
response_id=response_id_stream,
)
self._track_token_usage_internal(usage_data)
if usage_data:
self._track_token_usage_internal(usage_data)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
@@ -2148,6 +2142,7 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return parsed_object
@@ -2159,14 +2154,15 @@ class OpenAICompletion(BaseLLM):
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
usage=usage_data,
)
return accumulated_content
stream: AsyncIterator[
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
] = await self._async_client.chat.completions.create(**params)
usage_data = {"total_tokens": 0}
usage_data = None
async for chunk in stream:
response_id_stream = chunk.id if hasattr(chunk, "id") else None
@@ -2356,8 +2352,8 @@ class OpenAICompletion(BaseLLM):
from crewai_files.uploaders.openai import OpenAIFileUploader
return OpenAIFileUploader(
client=self.client,
async_client=self.async_client,
client=self._client,
async_client=self._async_client,
)
except ImportError:
return None

View File

@@ -16,6 +16,8 @@ from dataclasses import dataclass, field
import os
from typing import Any
from pydantic import model_validator
from crewai.llms.providers.openai.completion import OpenAICompletion
@@ -140,31 +142,13 @@ class OpenAICompatibleCompletion(OpenAICompletion):
)
"""
def __init__(
self,
model: str,
provider: str,
api_key: str | None = None,
base_url: str | None = None,
default_headers: dict[str, str] | None = None,
**kwargs: Any,
) -> None:
"""Initialize OpenAI-compatible completion client.
@model_validator(mode="before")
@classmethod
def _resolve_provider_config(cls, data: Any) -> Any:
if not isinstance(data, dict):
return data
Args:
model: The model identifier.
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
api_key: Optional API key override. If not provided, uses the
provider's configured environment variable.
base_url: Optional base URL override. If not provided, uses the
provider's configured default or environment variable.
default_headers: Optional headers to merge with provider defaults.
**kwargs: Additional arguments passed to OpenAICompletion.
Raises:
ValueError: If the provider is not supported or required API key
is missing.
"""
provider = data.get("provider", "")
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
if config is None:
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
@@ -173,21 +157,15 @@ class OpenAICompatibleCompletion(OpenAICompletion):
f"Supported providers: {supported}"
)
resolved_api_key = self._resolve_api_key(api_key, config, provider)
resolved_base_url = self._resolve_base_url(base_url, config, provider)
resolved_headers = self._resolve_headers(default_headers, config)
super().__init__(
model=model,
provider=provider,
api_key=resolved_api_key,
base_url=resolved_base_url,
default_headers=resolved_headers,
**kwargs,
data["api_key"] = cls._resolve_api_key(data.get("api_key"), config, provider)
data["base_url"] = cls._resolve_base_url(data.get("base_url"), config, provider)
data["default_headers"] = cls._resolve_headers(
data.get("default_headers"), config
)
return data
@staticmethod
def _resolve_api_key(
self,
api_key: str | None,
config: ProviderConfig,
provider: str,
@@ -220,8 +198,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
return config.default_api_key
@staticmethod
def _resolve_base_url(
self,
base_url: str | None,
config: ProviderConfig,
provider: str,
@@ -249,8 +227,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
return resolved
@staticmethod
def _resolve_headers(
self,
headers: dict[str, str] | None,
config: ProviderConfig,
) -> dict[str, str] | None:

View File

@@ -1 +0,0 @@
"""Third-party LLM implementations for crewAI."""

View File

@@ -98,7 +98,7 @@ class EncodingFlow(Flow[EncodingState]):
_skip_auto_memory: bool = True
initial_state = EncodingState
initial_state: type[EncodingState] = EncodingState
def __init__(
self,

View File

@@ -65,7 +65,7 @@ class RecallFlow(Flow[RecallState]):
_skip_auto_memory: bool = True
initial_state = RecallState
initial_state: type[RecallState] = RecallState
def __init__(
self,

View File

@@ -148,6 +148,36 @@ class Memory(BaseModel):
_pending_saves: list[Future[Any]] = PrivateAttr(default_factory=list)
_pending_lock: threading.Lock = PrivateAttr(default_factory=threading.Lock)
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> Memory:
"""Deepcopy that handles unpickleable private attrs (ThreadPoolExecutor, Lock)."""
import copy as _copy
cls = type(self)
new = cls.__new__(cls)
if memo is None:
memo = {}
memo[id(self)] = new
object.__setattr__(new, "__dict__", _copy.deepcopy(self.__dict__, memo))
object.__setattr__(
new, "__pydantic_fields_set__", _copy.copy(self.__pydantic_fields_set__)
)
object.__setattr__(
new, "__pydantic_extra__", _copy.deepcopy(self.__pydantic_extra__, memo)
)
# Private attrs: create fresh pool/lock instead of deepcopying
private = {}
for k, v in (self.__pydantic_private__ or {}).items():
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
attr = self.__private_attributes__[k]
private[k] = attr.get_default()
else:
try:
private[k] = _copy.deepcopy(v, memo)
except Exception:
private[k] = v
object.__setattr__(new, "__pydantic_private__", private)
return new
def model_post_init(self, __context: Any) -> None:
"""Initialize runtime state from field values."""
self._config = MemoryConfig(

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
from collections import defaultdict
from typing import TYPE_CHECKING, Any
from pydantic import BaseModel, Field, InstanceOf
from pydantic import BaseModel, Field
from rich.box import HEAVY_EDGE
from rich.console import Console
from rich.table import Table
@@ -39,9 +39,9 @@ class CrewEvaluator:
def __init__(
self,
crew: Crew,
eval_llm: InstanceOf[BaseLLM] | str | None = None,
eval_llm: BaseLLM | str | None = None,
openai_model_name: str | None = None,
llm: InstanceOf[BaseLLM] | str | None = None,
llm: BaseLLM | str | None = None,
) -> None:
self.crew = crew
self.llm = eval_llm

View File

@@ -2,9 +2,10 @@
from __future__ import annotations
from typing import Annotated, Any, Literal, TypedDict
from typing import Annotated, Any, Literal
from pydantic import BaseModel, Field
from typing_extensions import TypedDict
from crewai.utilities.i18n import I18N, get_i18n

View File

@@ -1692,9 +1692,27 @@ def test_agent_with_knowledge_sources_works_with_copy():
) as mock_knowledge_storage:
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
agent.knowledge_storage = mock_knowledge_storage_instance
class _StubStorage(BaseKnowledgeStorage):
def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
return []
async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
return []
def save(self, documents):
pass
async def asave(self, documents):
pass
def reset(self):
pass
async def areset(self):
pass
mock_knowledge_storage.return_value = _StubStorage()
agent.knowledge_storage = _StubStorage()
agent_copy = agent.copy()

View File

@@ -4,13 +4,55 @@ Tests the Flow-based agent executor implementation including state management,
flow methods, routing logic, and error handling.
"""
from __future__ import annotations
import asyncio
import time
from typing import Any
from unittest.mock import AsyncMock, Mock, patch
import pytest
from crewai.agents.tools_handler import ToolsHandler as _ToolsHandler
from crewai.agents.step_executor import StepExecutor
def _build_executor(**kwargs: Any) -> AgentExecutor:
"""Create an AgentExecutor without validation — for unit tests.
Uses model_construct to skip Pydantic validators so plain Mock()
objects are accepted for typed fields like llm, agent, crew, task.
"""
executor = AgentExecutor.model_construct(**kwargs)
executor._state = AgentExecutorState()
executor._methods = {}
executor._method_outputs = []
executor._completed_methods = set()
executor._fired_or_listeners = set()
executor._pending_and_listeners = {}
executor._method_execution_counts = {}
executor._method_call_counts = {}
executor._event_futures = []
executor._human_feedback_method_outputs = {}
executor._input_history = []
executor._is_execution_resuming = False
import threading
executor._state_lock = threading.Lock()
executor._or_listeners_lock = threading.Lock()
executor._execution_lock = threading.Lock()
executor._finalize_lock = threading.Lock()
executor._finalize_called = False
executor._is_executing = False
executor._has_been_invoked = False
executor._last_parser_error = None
executor._last_context_error = None
executor._step_executor = None
executor._planner_observer = None
from crewai.utilities.printer import Printer
executor._printer = Printer()
from crewai.utilities.i18n import get_i18n
executor._i18n = kwargs.get("i18n") or get_i18n()
return executor
from crewai.agents.planner_observer import PlannerObserver
from crewai.experimental.agent_executor import (
AgentExecutorState,
@@ -75,6 +117,7 @@ class TestAgentExecutor:
"""Create mock dependencies for executor."""
llm = Mock()
llm.supports_stop_words.return_value = True
llm.stop = []
task = Mock()
task.description = "Test task"
@@ -94,7 +137,7 @@ class TestAgentExecutor:
prompt = {"prompt": "Test prompt with {input}, {tool_names}, {tools}"}
tools = []
tools_handler = Mock()
tools_handler = Mock(spec=_ToolsHandler)
return {
"llm": llm,
@@ -112,7 +155,7 @@ class TestAgentExecutor:
def test_executor_initialization(self, mock_dependencies):
"""Test AgentExecutor initialization."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.llm == mock_dependencies["llm"]
assert executor.task == mock_dependencies["task"]
@@ -126,7 +169,7 @@ class TestAgentExecutor:
with patch.object(
AgentExecutor, "_show_start_logs"
) as mock_show_start:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.initialize_reasoning()
assert result == "initialized"
@@ -134,7 +177,7 @@ class TestAgentExecutor:
def test_check_max_iterations_not_reached(self, mock_dependencies):
"""Test routing when iterations < max."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.iterations = 5
result = executor.check_max_iterations()
@@ -142,7 +185,7 @@ class TestAgentExecutor:
def test_check_max_iterations_reached(self, mock_dependencies):
"""Test routing when iterations >= max."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.iterations = 10
result = executor.check_max_iterations()
@@ -150,7 +193,7 @@ class TestAgentExecutor:
def test_route_by_answer_type_action(self, mock_dependencies):
"""Test routing for AgentAction."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
@@ -160,7 +203,7 @@ class TestAgentExecutor:
def test_route_by_answer_type_finish(self, mock_dependencies):
"""Test routing for AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thoughts", output="Final answer", text="complete"
)
@@ -170,7 +213,7 @@ class TestAgentExecutor:
def test_continue_iteration(self, mock_dependencies):
"""Test iteration continuation."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.continue_iteration()
@@ -179,7 +222,7 @@ class TestAgentExecutor:
def test_finalize_success(self, mock_dependencies):
"""Test finalize with valid AgentFinish."""
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentFinish(
thought="final thinking", output="Done", text="complete"
)
@@ -192,7 +235,7 @@ class TestAgentExecutor:
def test_finalize_failure(self, mock_dependencies):
"""Test finalize skips when given AgentAction instead of AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="search", tool_input="query", text="action text"
)
@@ -208,7 +251,7 @@ class TestAgentExecutor:
):
"""Finalize should skip synthesis when last todo is already a complete answer."""
with patch.object(AgentExecutor, "_show_logs") as mock_show_logs:
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.todos.items = [
TodoItem(
step_number=1,
@@ -252,7 +295,7 @@ class TestAgentExecutor:
):
"""Finalize should still synthesize when response_model is configured."""
with patch.object(AgentExecutor, "_show_logs"):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.response_model = Mock()
executor.state.todos.items = [
TodoItem(
@@ -287,7 +330,7 @@ class TestAgentExecutor:
def test_format_prompt(self, mock_dependencies):
"""Test prompt formatting."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
inputs = {"input": "test input", "tool_names": "tool1, tool2", "tools": "desc"}
result = executor._format_prompt("Prompt {input} {tool_names} {tools}", inputs)
@@ -298,18 +341,18 @@ class TestAgentExecutor:
def test_is_training_mode_false(self, mock_dependencies):
"""Test training mode detection when not in training."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor._is_training_mode() is False
def test_is_training_mode_true(self, mock_dependencies):
"""Test training mode detection when in training."""
mock_dependencies["crew"]._train = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor._is_training_mode() is True
def test_append_message_to_state(self, mock_dependencies):
"""Test message appending to state."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
initial_count = len(executor.state.messages)
executor._append_message_to_state("test message")
@@ -322,7 +365,7 @@ class TestAgentExecutor:
callback = Mock()
mock_dependencies["step_callback"] = callback
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
answer = AgentFinish(thought="thinking", output="test", text="final")
executor._invoke_step_callback(answer)
@@ -332,7 +375,7 @@ class TestAgentExecutor:
def test_invoke_step_callback_none(self, mock_dependencies):
"""Test step callback when none provided."""
mock_dependencies["step_callback"] = None
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
# Should not raise error
executor._invoke_step_callback(
@@ -346,7 +389,7 @@ class TestAgentExecutor:
"""Test async step callback scheduling when already in an event loop."""
callback = AsyncMock()
mock_dependencies["step_callback"] = callback
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
answer = AgentFinish(thought="thinking", output="test", text="final")
with patch("crewai.experimental.agent_executor.asyncio.run") as mock_run:
@@ -364,6 +407,7 @@ class TestStepExecutorCriticalFixes:
def mock_dependencies(self):
"""Create mock dependencies for AgentExecutor tests in this class."""
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -393,6 +437,7 @@ class TestStepExecutorCriticalFixes:
@pytest.fixture
def step_executor(self):
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
agent = Mock()
@@ -485,7 +530,7 @@ class TestStepExecutorCriticalFixes:
mock_handle_exception.return_value = None
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor._last_parser_error = OutputParserError("test error")
initial_iterations = executor.state.iterations
@@ -500,7 +545,7 @@ class TestStepExecutorCriticalFixes:
self, mock_handle_context, mock_dependencies
):
"""Test recovery from context length error."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor._last_context_error = Exception("context too long")
initial_iterations = executor.state.iterations
@@ -513,16 +558,16 @@ class TestStepExecutorCriticalFixes:
def test_use_stop_words_property(self, mock_dependencies):
"""Test use_stop_words property."""
mock_dependencies["llm"].supports_stop_words.return_value = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.use_stop_words is True
mock_dependencies["llm"].supports_stop_words.return_value = False
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
assert executor.use_stop_words is False
def test_compatibility_properties(self, mock_dependencies):
"""Test compatibility properties for mixin."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.messages = [{"role": "user", "content": "test"}]
executor.state.iterations = 5
@@ -538,6 +583,7 @@ class TestFlowErrorHandling:
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -575,7 +621,7 @@ class TestFlowErrorHandling:
mock_enforce_rpm.return_value = None
mock_get_llm.side_effect = OutputParserError("parse failed")
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "parser_error"
@@ -596,7 +642,7 @@ class TestFlowErrorHandling:
mock_get_llm.side_effect = Exception("context length")
mock_is_context_exceeded.return_value = True
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
result = executor.call_llm_and_parse()
assert result == "context_error"
@@ -610,6 +656,7 @@ class TestFlowInvoke:
def mock_dependencies(self):
"""Create mock dependencies."""
llm = Mock()
llm.stop = []
task = Mock()
task.description = "Test"
task.human_input = False
@@ -646,7 +693,7 @@ class TestFlowInvoke:
mock_dependencies,
):
"""Test successful invoke without human feedback."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
# Mock kickoff to set the final answer in state
def mock_kickoff_side_effect():
@@ -666,7 +713,7 @@ class TestFlowInvoke:
@patch.object(AgentExecutor, "kickoff")
def test_invoke_failure_no_agent_finish(self, mock_kickoff, mock_dependencies):
"""Test invoke fails without AgentFinish."""
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
executor.state.current_answer = AgentAction(
thought="thinking", tool="test", tool_input="test", text="action text"
)
@@ -689,7 +736,7 @@ class TestFlowInvoke:
"system": "System: {input}",
"user": "User: {input} {tool_names} {tools}",
}
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def mock_kickoff_side_effect():
executor.state.current_answer = AgentFinish(
@@ -713,6 +760,7 @@ class TestNativeToolExecution:
@pytest.fixture
def mock_dependencies(self):
llm = Mock()
llm.stop = []
llm.supports_stop_words.return_value = True
task = Mock()
@@ -734,7 +782,7 @@ class TestNativeToolExecution:
prompt = {"prompt": "Test {input} {tool_names} {tools}"}
tools_handler = Mock()
tools_handler = Mock(spec=_ToolsHandler)
tools_handler.cache = None
return {
@@ -754,7 +802,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_runs_parallel_for_multiple_calls(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def slow_one() -> str:
time.sleep(0.2)
@@ -790,7 +838,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_falls_back_to_sequential_for_result_as_answer(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
def slow_one() -> str:
time.sleep(0.2)
@@ -832,7 +880,7 @@ class TestNativeToolExecution:
def test_execute_native_tool_result_as_answer_short_circuits_remaining_calls(
self, mock_dependencies
):
executor = AgentExecutor(**mock_dependencies)
executor = _build_executor(**mock_dependencies)
call_counts = {"slow_one": 0, "slow_two": 0}
def slow_one() -> str:
@@ -879,30 +927,6 @@ class TestNativeToolExecution:
assert len(tool_messages) == 1
assert tool_messages[0]["tool_call_id"] == "call_1"
def test_check_native_todo_completion_requires_current_todo(
self, mock_dependencies
):
from crewai.utilities.planning_types import TodoList
executor = AgentExecutor(**mock_dependencies)
# No current todo → not satisfied
executor.state.todos = TodoList(items=[])
assert executor.check_native_todo_completion() == "todo_not_satisfied"
# With a current todo that has tool_to_use → satisfied
running = TodoItem(
step_number=1,
description="Use the expected tool",
tool_to_use="expected_tool",
status="running",
)
executor.state.todos = TodoList(items=[running])
assert executor.check_native_todo_completion() == "todo_satisfied"
# With a current todo without tool_to_use → still satisfied
running.tool_to_use = None
assert executor.check_native_todo_completion() == "todo_satisfied"
class TestPlannerObserver:

View File

@@ -1,7 +1,11 @@
interactions:
- request:
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
uses tools. This is padding text to ensure the prompt is large enough for caching.
body: '{"input":[{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","instructions":"You
are a helpful assistant that uses tools. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
@@ -68,13 +72,9 @@ interactions:
for caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is large
enough for caching. "},{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
city name"}},"required":["location"],"additionalProperties":false}}}]}'
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
city name"}},"required":["location"]}}]}'
headers:
User-Agent:
- X-USER-AGENT-XXX
@@ -87,7 +87,7 @@ interactions:
connection:
- keep-alive
content-length:
- '6158'
- '6065'
content-type:
- application/json
host:
@@ -109,26 +109,113 @@ interactions:
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.3
- 3.13.12
method: POST
uri: https://api.openai.com/v1/chat/completions
uri: https://api.openai.com/v1/responses
response:
body:
string: "{\n \"id\": \"chatcmpl-D7mXQCgT3p3ViImkiqDiZGqLREQtp\",\n \"object\":
\"chat.completion\",\n \"created\": 1770747248,\n \"model\": \"gpt-4.1-2025-04-14\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
\ \"id\": \"call_9ZqMavn3J1fBnQEaqpYol0Bd\",\n \"type\":
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
\ \"arguments\": \"{\\\"location\\\":\\\"Tokyo\\\"}\"\n }\n
\ }\n ],\n \"refusal\": null,\n \"annotations\":
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
string: "{\n \"id\": \"resp_0d68149bcc0d14810069caf464a4b48197bd9f098abb2f6303\",\n
\ \"object\": \"response\",\n \"created_at\": 1774908516,\n \"status\":
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
\"developer\"\n },\n \"completed_at\": 1774908517,\n \"error\": null,\n
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
\"You are a helpful assistant that uses tools. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
\ \"output\": [\n {\n \"id\": \"fc_0d68149bcc0d14810069caf46568088197a33be67f16a1fa09\",\n
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
\"{\\\"location\\\":\\\"Tokyo\\\"}\",\n \"call_id\": \"call_74rwmYse0DE4JFaFGyAFx9bu\",\n
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
\"function\",\n \"description\": \"Get the current weather for a location\",\n
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
\"string\",\n \"description\": \"The city name\"\n }\n
\ },\n \"required\": [\n \"location\"\n ],\n
\ \"additionalProperties\": false\n },\n \"strict\": true\n
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
{\n \"cached_tokens\": 0\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
\ \"user\": null,\n \"metadata\": {}\n}"
headers:
CF-RAY:
- CF-RAY-XXX
@@ -137,7 +224,7 @@ interactions:
Content-Type:
- application/json
Date:
- Tue, 10 Feb 2026 18:14:08 GMT
- Mon, 30 Mar 2026 22:08:37 GMT
Server:
- cloudflare
Strict-Transport-Security:
@@ -146,8 +233,6 @@ interactions:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
@@ -155,15 +240,13 @@ interactions:
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '484'
- '1085'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
set-cookie:
- SET-COOKIE-XXX
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:
@@ -182,8 +265,12 @@ interactions:
code: 200
message: OK
- request:
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
uses tools. This is padding text to ensure the prompt is large enough for caching.
body: '{"input":[{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","instructions":"You
are a helpful assistant that uses tools. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
@@ -250,13 +337,9 @@ interactions:
for caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is large
enough for caching. "},{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
city name"}},"required":["location"],"additionalProperties":false}}}]}'
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
city name"}},"required":["location"]}}]}'
headers:
User-Agent:
- X-USER-AGENT-XXX
@@ -269,7 +352,7 @@ interactions:
connection:
- keep-alive
content-length:
- '6158'
- '6065'
content-type:
- application/json
cookie:
@@ -293,26 +376,113 @@ interactions:
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.3
- 3.13.12
method: POST
uri: https://api.openai.com/v1/chat/completions
uri: https://api.openai.com/v1/responses
response:
body:
string: "{\n \"id\": \"chatcmpl-D7mXR8k9vk8TlGvGXlrQSI7iNeAN1\",\n \"object\":
\"chat.completion\",\n \"created\": 1770747249,\n \"model\": \"gpt-4.1-2025-04-14\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
\ \"id\": \"call_6PeUBlRPG8JcV2lspmLjJbnn\",\n \"type\":
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
\ \"arguments\": \"{\\\"location\\\":\\\"Paris\\\"}\"\n }\n
\ }\n ],\n \"refusal\": null,\n \"annotations\":
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
string: "{\n \"id\": \"resp_0525bf798202137e0069caf465ee3c8196aa7c83da1c369eb7\",\n
\ \"object\": \"response\",\n \"created_at\": 1774908517,\n \"status\":
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
\"developer\"\n },\n \"completed_at\": 1774908518,\n \"error\": null,\n
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
\"You are a helpful assistant that uses tools. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. This is
padding text to ensure the prompt is large enough for caching. This is padding
text to ensure the prompt is large enough for caching. This is padding text
to ensure the prompt is large enough for caching. This is padding text to
ensure the prompt is large enough for caching. This is padding text to ensure
the prompt is large enough for caching. This is padding text to ensure the
prompt is large enough for caching. This is padding text to ensure the prompt
is large enough for caching. This is padding text to ensure the prompt is
large enough for caching. This is padding text to ensure the prompt is large
enough for caching. This is padding text to ensure the prompt is large enough
for caching. This is padding text to ensure the prompt is large enough for
caching. This is padding text to ensure the prompt is large enough for caching.
This is padding text to ensure the prompt is large enough for caching. This
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
\ \"output\": [\n {\n \"id\": \"fc_0525bf798202137e0069caf46666588196a2ec20dc515a6a91\",\n
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
\"{\\\"location\\\":\\\"Paris\\\"}\",\n \"call_id\": \"call_LJAGuYYZPjNxSgg0TUgGpT44\",\n
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
\"function\",\n \"description\": \"Get the current weather for a location\",\n
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
\"string\",\n \"description\": \"The city name\"\n }\n
\ },\n \"required\": [\n \"location\"\n ],\n
\ \"additionalProperties\": false\n },\n \"strict\": true\n
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
{\n \"cached_tokens\": 1152\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
\ \"user\": null,\n \"metadata\": {}\n}"
headers:
CF-RAY:
- CF-RAY-XXX
@@ -321,7 +491,7 @@ interactions:
Content-Type:
- application/json
Date:
- Tue, 10 Feb 2026 18:14:09 GMT
- Mon, 30 Mar 2026 22:08:38 GMT
Server:
- cloudflare
Strict-Transport-Security:
@@ -330,8 +500,6 @@ interactions:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
@@ -339,15 +507,11 @@ interactions:
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '528'
- '653'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
set-cookie:
- SET-COOKIE-XXX
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:

View File

@@ -0,0 +1,108 @@
interactions:
- request:
body: '{"messages":[{"role":"user","content":"Say hello"}],"model":"gpt-4o-mini"}'
headers:
User-Agent:
- X-USER-AGENT-XXX
accept:
- application/json
accept-encoding:
- ACCEPT-ENCODING-XXX
authorization:
- AUTHORIZATION-XXX
connection:
- keep-alive
content-length:
- '74'
content-type:
- application/json
host:
- api.openai.com
x-stainless-arch:
- X-STAINLESS-ARCH-XXX
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- X-STAINLESS-OS-XXX
x-stainless-package-version:
- 1.83.0
x-stainless-read-timeout:
- X-STAINLESS-READ-TIMEOUT-XXX
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.13.2
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
body:
string: "{\n \"id\": \"chatcmpl-DPS8YQSwQ3pZKZztIoIe1eYodMqh2\",\n \"object\":
\"chat.completion\",\n \"created\": 1774958730,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"Hello! How can I assist you today?\",\n
\ \"refusal\": null,\n \"annotations\": []\n },\n \"logprobs\":
null,\n \"finish_reason\": \"stop\"\n }\n ],\n \"usage\": {\n \"prompt_tokens\":
9,\n \"completion_tokens\": 9,\n \"total_tokens\": 18,\n \"prompt_tokens_details\":
{\n \"cached_tokens\": 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_709f182cb4\"\n}\n"
headers:
CF-Cache-Status:
- DYNAMIC
CF-Ray:
- 9e4f38fc5d9d82e8-GIG
Connection:
- keep-alive
Content-Type:
- application/json
Date:
- Tue, 31 Mar 2026 12:05:30 GMT
Server:
- cloudflare
Strict-Transport-Security:
- STS-XXX
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- X-CONTENT-TYPE-XXX
access-control-expose-headers:
- ACCESS-CONTROL-XXX
alt-svc:
- h3=":443"; ma=86400
content-length:
- '839'
openai-organization:
- OPENAI-ORG-XXX
openai-processing-ms:
- '680'
openai-project:
- OPENAI-PROJECT-XXX
openai-version:
- '2020-10-01'
set-cookie:
- SET-COOKIE-XXX
x-openai-proxy-wasm:
- v0.1
x-ratelimit-limit-requests:
- X-RATELIMIT-LIMIT-REQUESTS-XXX
x-ratelimit-limit-tokens:
- X-RATELIMIT-LIMIT-TOKENS-XXX
x-ratelimit-remaining-requests:
- X-RATELIMIT-REMAINING-REQUESTS-XXX
x-ratelimit-remaining-tokens:
- X-RATELIMIT-REMAINING-TOKENS-XXX
x-ratelimit-reset-requests:
- X-RATELIMIT-RESET-REQUESTS-XXX
x-ratelimit-reset-tokens:
- X-RATELIMIT-RESET-TOKENS-XXX
x-request-id:
- X-REQUEST-ID-XXX
status:
code: 200
message: OK
version: 1

View File

@@ -0,0 +1,176 @@
from typing import Any
from unittest.mock import patch
import pytest
from pydantic import BaseModel
from crewai.events.event_bus import CrewAIEventsBus
from crewai.events.types.llm_events import LLMCallCompletedEvent, LLMCallType
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
class TestLLMCallCompletedEventUsageField:
def test_accepts_usage_dict(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30},
)
assert event.usage == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}
def test_usage_defaults_to_none(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
)
assert event.usage is None
def test_accepts_none_usage(self):
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage=None,
)
assert event.usage is None
def test_accepts_nested_usage_dict(self):
usage = {
"prompt_tokens": 100,
"completion_tokens": 200,
"total_tokens": 300,
"prompt_tokens_details": {"cached_tokens": 50},
}
event = LLMCallCompletedEvent(
response="hello",
call_type=LLMCallType.LLM_CALL,
call_id="test-id",
usage=usage,
)
assert event.usage["prompt_tokens_details"]["cached_tokens"] == 50
class TestUsageToDict:
def test_none_returns_none(self):
assert LLM._usage_to_dict(None) is None
def test_dict_passes_through(self):
usage = {"prompt_tokens": 10, "total_tokens": 30}
assert LLM._usage_to_dict(usage) is usage
def test_pydantic_model_uses_model_dump(self):
class Usage(BaseModel):
prompt_tokens: int = 10
completion_tokens: int = 20
total_tokens: int = 30
result = LLM._usage_to_dict(Usage())
assert result == {
"prompt_tokens": 10,
"completion_tokens": 20,
"total_tokens": 30,
}
def test_object_with_dict_attr(self):
class UsageObj:
def __init__(self):
self.prompt_tokens = 5
self.completion_tokens = 15
self.total_tokens = 20
result = LLM._usage_to_dict(UsageObj())
assert result == {
"prompt_tokens": 5,
"completion_tokens": 15,
"total_tokens": 20,
}
def test_object_with_dict_excludes_private_attrs(self):
class UsageObj:
def __init__(self):
self.total_tokens = 42
self._internal = "hidden"
result = LLM._usage_to_dict(UsageObj())
assert result == {"total_tokens": 42}
assert "_internal" not in result
def test_unsupported_type_returns_none(self):
assert LLM._usage_to_dict(42) is None
assert LLM._usage_to_dict("string") is None
class _StubLLM(BaseLLM):
"""Minimal concrete BaseLLM for testing event emission."""
model: str = "test-model"
def call(self, *args: Any, **kwargs: Any) -> str:
return ""
async def acall(self, *args: Any, **kwargs: Any) -> str:
return ""
def supports_function_calling(self) -> bool:
return False
def supports_stop_words(self) -> bool:
return True
class TestEmitCallCompletedEventPassesUsage:
@pytest.fixture
def mock_emit(self):
with patch.object(CrewAIEventsBus, "emit") as mock:
yield mock
@pytest.fixture
def llm(self):
return _StubLLM(model="test-model")
def test_usage_is_passed_to_event(self, mock_emit, llm):
usage_data = {"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}
llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
usage=usage_data,
)
mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage == usage_data
def test_none_usage_is_passed_to_event(self, mock_emit, llm):
llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
usage=None,
)
mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage is None
def test_usage_omitted_defaults_to_none(self, mock_emit, llm):
llm._emit_call_completed_event(
response="hello",
call_type=LLMCallType.LLM_CALL,
messages="test prompt",
)
mock_emit.assert_called_once()
event = mock_emit.call_args[1]["event"]
assert isinstance(event, LLMCallCompletedEvent)
assert event.usage is None

View File

@@ -132,12 +132,12 @@ def test_embedding_configuration_flow(
embedder_config = {
"provider": "sentence-transformer",
"model_name": "all-MiniLM-L6-v2",
"config": {"model_name": "all-MiniLM-L6-v2"},
}
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
mock_get_embedding.assert_called_once_with(embedder_config)
mock_get_embedding.assert_called_once_with(storage.embedder)
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")

View File

@@ -125,8 +125,8 @@ def test_anthropic_specific_parameters():
assert isinstance(llm, AnthropicCompletion)
assert llm.stop_sequences == ["Human:", "Assistant:"]
assert llm.stream == True
assert llm.client.max_retries == 5
assert llm.client.timeout == 60
assert llm._client.max_retries == 5
assert llm._client.timeout == 60
def test_anthropic_completion_call():
@@ -563,8 +563,8 @@ def test_anthropic_environment_variable_api_key():
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
assert llm.client is not None
assert hasattr(llm.client, 'messages')
assert llm._client is not None
assert hasattr(llm._client, 'messages')
def test_anthropic_token_usage_tracking():
@@ -574,7 +574,7 @@ def test_anthropic_token_usage_tracking():
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
# Mock the Anthropic response with usage information
with patch.object(llm.client.messages, 'create') as mock_create:
with patch.object(llm._client.messages, 'create') as mock_create:
mock_response = MagicMock()
mock_response.content = [MagicMock(text="test response")]
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
@@ -639,14 +639,14 @@ def test_anthropic_thinking():
assert isinstance(llm, AnthropicCompletion)
original_create = llm.client.messages.create
original_create = llm._client.messages.create
captured_params = {}
def capture_and_call(**kwargs):
captured_params.update(kwargs)
return original_create(**kwargs)
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
result = llm.call("What is the weather in Tokyo?")
assert result is not None
@@ -677,14 +677,14 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
assert isinstance(llm, AnthropicCompletion)
# Capture all messages.create calls to verify thinking blocks are included
original_create = llm.client.messages.create
original_create = llm._client.messages.create
captured_calls = []
def capture_and_call(**kwargs):
captured_calls.append(kwargs)
return original_create(**kwargs)
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
# First call - establishes context and generates thinking blocks
messages = [{"role": "user", "content": "What is 2+2?"}]
first_result = llm.call(messages)
@@ -695,8 +695,8 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
assert len(first_result) > 0
# Verify thinking blocks were stored after first response
assert len(llm.previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
first_thinking = llm.previous_thinking_blocks[0]
assert len(llm._previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
first_thinking = llm._previous_thinking_blocks[0]
assert first_thinking["type"] == "thinking"
assert "thinking" in first_thinking
assert "signature" in first_thinking

View File

@@ -66,7 +66,7 @@ def test_azure_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Azure client responses
with patch.object(completion.client, 'complete') as mock_complete:
with patch.object(completion._client, 'complete') as mock_complete:
# Mock tool call in response with proper type
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
mock_tool_call.function.name = "get_weather"
@@ -698,7 +698,7 @@ def test_azure_environment_variable_endpoint():
}):
llm = LLM(model="azure/gpt-4")
assert llm.client is not None
assert llm._client is not None
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
@@ -709,7 +709,7 @@ def test_azure_token_usage_tracking():
llm = LLM(model="azure/gpt-4")
# Mock the Azure response with usage information
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_message = MagicMock()
mock_message.content = "test response"
mock_message.tool_calls = None
@@ -747,7 +747,7 @@ def test_azure_http_error_handling():
llm = LLM(model="azure/gpt-4")
# Mock an HTTP error
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
with pytest.raises(HttpResponseError):
@@ -966,7 +966,7 @@ def test_azure_improved_error_messages():
llm = LLM(model="azure/gpt-4")
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
error_401 = HttpResponseError(message="Unauthorized")
error_401.status_code = 401
mock_complete.side_effect = error_401
@@ -1327,7 +1327,7 @@ def test_azure_stop_words_not_applied_to_structured_output():
# Without the fix, this would be truncated at "Observation:" breaking the JSON
json_response = '{"finding": "The data shows growth", "observation": "Observation: This confirms the hypothesis"}'
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_message = MagicMock()
mock_message.content = json_response
mock_message.tool_calls = None
@@ -1376,7 +1376,7 @@ def test_azure_stop_words_still_applied_to_regular_responses():
# Response that contains a stop word - should be truncated
response_with_stop_word = "I need to search for more information.\n\nAction: search\nObservation: Found results"
with patch.object(llm.client, 'complete') as mock_complete:
with patch.object(llm._client, 'complete') as mock_complete:
mock_message = MagicMock()
mock_message.content = response_with_stop_word
mock_message.tool_calls = None

View File

@@ -674,7 +674,7 @@ def test_bedrock_token_usage_tracking():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Mock the Bedrock response with usage information
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {
@@ -719,7 +719,7 @@ def test_bedrock_tool_use_conversation_flow():
available_functions = {"get_weather": mock_weather_tool}
# Mock the Bedrock client responses
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
# First response: tool use request
tool_use_response = {
'output': {
@@ -805,7 +805,7 @@ def test_bedrock_client_error_handling():
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
# Test ValidationException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ValidationException',
@@ -819,7 +819,7 @@ def test_bedrock_client_error_handling():
assert "validation" in str(exc_info.value).lower()
# Test ThrottlingException
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
error_response = {
'Error': {
'Code': 'ThrottlingException',
@@ -861,7 +861,7 @@ def test_bedrock_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client, 'converse') as mock_converse:
with patch.object(llm._client, 'converse') as mock_converse:
mock_response = {
'output': {
'message': {

View File

@@ -556,8 +556,8 @@ def test_gemini_environment_variable_api_key():
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
llm = LLM(model="google/gemini-2.0-flash-001")
assert llm.client is not None
assert hasattr(llm.client, 'models')
assert llm._client is not None
assert hasattr(llm._client, 'models')
assert llm.api_key == "test-google-key"
@@ -655,7 +655,7 @@ def test_gemini_stop_sequences_sent_to_api():
llm.stop = ["\nObservation:", "\nThought:"]
# Patch the API call to capture parameters without making real call
with patch.object(llm.client.models, 'generate_content') as mock_generate:
with patch.object(llm._client.models, 'generate_content') as mock_generate:
mock_response = MagicMock()
mock_response.text = "Hello"
mock_response.candidates = []

View File

@@ -371,11 +371,11 @@ def test_openai_client_setup_with_extra_arguments():
assert llm.top_p == 0.5
# Check that client parameters are properly configured
assert llm.client.max_retries == 3
assert llm.client.timeout == 30
assert llm._client.max_retries == 3
assert llm._client.timeout == 30
# Test that parameters are properly used in API calls
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -396,7 +396,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
"""
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
with patch.object(llm.client.chat.completions, 'create') as mock_create:
with patch.object(llm._client.chat.completions, 'create') as mock_create:
mock_create.return_value = MagicMock(
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
@@ -507,7 +507,7 @@ def test_openai_streaming_with_response_model():
llm = LLM(model="openai/gpt-4o", stream=True)
with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
with patch.object(llm._client.beta.chat.completions, "stream") as mock_stream:
# Create mock chunks with content.delta event structure
mock_chunk1 = MagicMock()
mock_chunk1.type = "content.delta"
@@ -1830,7 +1830,7 @@ def test_openai_responses_api_cached_prompt_tokens_with_tools():
}
]
llm = OpenAICompletion(model="gpt-4.1", api='response')
llm = OpenAICompletion(model="gpt-4.1", api='responses')
# First call with tool
llm.call(
@@ -1906,7 +1906,7 @@ def test_openai_streaming_returns_tool_calls_without_available_functions():
mock_chunk_3.id = "chatcmpl-1"
with patch.object(
llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
llm._client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
):
result = llm.call(
messages=[{"role": "user", "content": "Calculate 1+1"}],
@@ -1997,7 +1997,7 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_funct
return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])
with patch.object(
llm.async_client.chat.completions, "create", side_effect=mock_create
llm._async_client.chat.completions, "create", side_effect=mock_create
):
result = await llm.acall(
messages=[{"role": "user", "content": "Calculate 1+1"}],

View File

@@ -3,6 +3,8 @@
from unittest.mock import MagicMock, patch
import pytest
from pydantic import ValidationError
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
KnowledgeStorage,
)
@@ -59,7 +61,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
"Unsupported provider: invalid_provider"
)
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
with pytest.raises(ValidationError):
KnowledgeStorage(
embedder={"provider": "invalid_provider"},
collection_name="invalid_embedding_test",

View File

@@ -873,7 +873,7 @@ class TestAutoPersistence:
# Create flow WITHOUT persistence
flow = TestFlow()
assert flow._persistence is None # No persistence initially
assert flow.persistence is None # No persistence initially
# kickoff should auto-create persistence when HumanFeedbackPending is raised
result = flow.kickoff()
@@ -882,11 +882,11 @@ class TestAutoPersistence:
assert isinstance(result, HumanFeedbackPending)
# Persistence should have been auto-created
assert flow._persistence is not None
assert flow.persistence is not None
# The pending feedback should be saved
flow_id = result.context.flow_id
loaded = flow._persistence.load_pending_feedback(flow_id)
loaded = flow.persistence.load_pending_feedback(flow_id)
assert loaded is not None

View File

@@ -752,11 +752,7 @@ def test_litellm_retry_catches_litellm_unsupported_params_error(caplog):
raise litellm_error
return MagicMock(
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
usage=MagicMock(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)
with patch("litellm.completion", side_effect=mock_completion):
@@ -787,11 +783,7 @@ def test_litellm_retry_catches_openai_api_stop_error(caplog):
raise api_error
return MagicMock(
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
usage=MagicMock(
prompt_tokens=10,
completion_tokens=5,
total_tokens=15,
),
usage={"prompt_tokens": 10, "completion_tokens": 5, "total_tokens": 15},
)
with patch("litellm.completion", side_effect=mock_completion):

View File

@@ -1,5 +1,5 @@
from typing import Any, ClassVar
from unittest.mock import Mock, patch
from unittest.mock import Mock, create_autospec, patch
import pytest
from crewai.agent import Agent
@@ -372,8 +372,11 @@ def test_internal_crew_with_mcp():
mock_adapter = Mock()
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
mock_llm = Mock()
mock_llm.__class__ = BaseLLM
class _StubLLM(BaseLLM):
def call(self, *a: Any, **kw: Any) -> str:
return ""
mock_llm = create_autospec(_StubLLM(model="stub"), instance=True)
with (
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,

View File

@@ -879,6 +879,35 @@ def test_llm_emits_call_started_event():
assert started_events[0].task_id is None
@pytest.mark.vcr()
def test_llm_completed_event_includes_usage():
completed_events: list[LLMCallCompletedEvent] = []
condition = threading.Condition()
@crewai_event_bus.on(LLMCallCompletedEvent)
def handle_llm_call_completed(source, event):
with condition:
completed_events.append(event)
condition.notify()
llm = LLM(model="gpt-4o-mini")
llm.call("Say hello")
with condition:
success = condition.wait_for(
lambda: len(completed_events) >= 1,
timeout=10,
)
assert success, "Timeout waiting for LLMCallCompletedEvent"
event = completed_events[0]
assert event.usage is not None
assert isinstance(event.usage, dict)
assert event.usage.get("prompt_tokens", 0) > 0
assert event.usage.get("completion_tokens", 0) > 0
assert event.usage.get("total_tokens", 0) > 0
@pytest.mark.vcr()
def test_llm_emits_call_failed_event():
received_events = []

View File

@@ -1,3 +1,3 @@
"""CrewAI development tools."""
__version__ = "1.13.0rc1"
__version__ = "1.13.0a5"