mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-31 08:08:17 +00:00
Compare commits
8 Commits
docs/file-
...
gl/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd059a395 | ||
|
|
297f7a0426 | ||
|
|
dfc0f9a317 | ||
|
|
ef79456968 | ||
|
|
6c7ea422e7 | ||
|
|
bb9bcd6823 | ||
|
|
ac14b9127e | ||
|
|
98b7626784 |
@@ -6,6 +6,7 @@ import warnings
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -19,6 +20,9 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
|
||||
CrewAgentExecutor.model_rebuild()
|
||||
|
||||
|
||||
def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
"""Suppress Pydantic deprecation warnings using targeted monkey patch."""
|
||||
original_warn = warnings.warn
|
||||
|
||||
@@ -25,7 +25,6 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
@@ -167,10 +166,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
function_calling_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -185,7 +184,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
|
||||
knowledge_storage: BaseKnowledgeStorage | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
|
||||
@@ -14,8 +14,15 @@ import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
@@ -23,6 +30,7 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.core.providers.human_input import ExecutorContext, get_provider
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.logging_events import (
|
||||
@@ -38,6 +46,9 @@ from crewai.hooks.tool_hooks import (
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
convert_tools_to_openai_schema,
|
||||
@@ -59,106 +70,65 @@ from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
execute_tool_and_check_finality,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
|
||||
"""Executor for crew agents.
|
||||
|
||||
Manages the execution lifecycle of an agent including prompt formatting,
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew,
|
||||
agent: Agent,
|
||||
prompt: SystemPromptResult | StandardPromptResult,
|
||||
max_iter: int,
|
||||
tools: list[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | Any | None = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
tools: Available tools.
|
||||
tools_names: Tool names string.
|
||||
stop_words: Stop word list.
|
||||
tools_description: Tool descriptions.
|
||||
tools_handler: Tool handler instance.
|
||||
step_callback: Optional step callback.
|
||||
original_tools: Original tool list.
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
"""
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.callbacks = callbacks or []
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.step_callback = step_callback
|
||||
self.tools_description = tools_description
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.response_model = response_model
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
llm: BaseLLM
|
||||
task: Task | None = None
|
||||
crew: Crew | None = None
|
||||
agent: Agent
|
||||
prompt: SystemPromptResult | StandardPromptResult
|
||||
max_iter: int
|
||||
tools: list[CrewStructuredTool]
|
||||
tools_names: str
|
||||
stop: list[str] = Field(alias="stop_words")
|
||||
tools_description: str
|
||||
tools_handler: ToolsHandler
|
||||
step_callback: Any = None
|
||||
original_tools: list[BaseTool] = Field(default_factory=list)
|
||||
function_calling_llm: BaseLLM | Any | None = None
|
||||
respect_context_window: bool = False
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None
|
||||
callbacks: list[Any] = Field(default_factory=list)
|
||||
response_model: type[BaseModel] | None = None
|
||||
i18n: I18N | None = Field(default=None, exclude=True)
|
||||
ask_for_human_input: bool = False
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
iterations: int = 0
|
||||
log_error_after: int = 3
|
||||
before_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
after_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
_i18n: I18N = PrivateAttr()
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_executor(self) -> Self:
|
||||
self._i18n = self.i18n or get_i18n()
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
@@ -171,6 +141,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
else self.stop
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def use_stop_words(self) -> bool:
|
||||
@@ -1687,14 +1658,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -73,6 +73,7 @@ class PlusAPI:
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
tools_metadata: list[dict[str, Any]] | None = None,
|
||||
) -> httpx.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
@@ -81,6 +82,9 @@ class PlusAPI:
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata}
|
||||
if tools_metadata is not None
|
||||
else None,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.utils import (
|
||||
build_env_with_tool_repository_credentials,
|
||||
extract_available_exports,
|
||||
extract_tools_metadata,
|
||||
get_project_description,
|
||||
get_project_name,
|
||||
get_project_version,
|
||||
@@ -101,6 +102,18 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(
|
||||
f"[green]Found these tools to publish: {', '.join([e['name'] for e in available_exports])}[/green]"
|
||||
)
|
||||
|
||||
console.print("[bold blue]Extracting tool metadata...[/bold blue]")
|
||||
try:
|
||||
tools_metadata = extract_tools_metadata()
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract tool metadata: {e}[/yellow]\n"
|
||||
f"Publishing will continue without detailed metadata."
|
||||
)
|
||||
tools_metadata = []
|
||||
|
||||
self._print_tools_preview(tools_metadata)
|
||||
self._print_current_organization()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_build_dir:
|
||||
@@ -118,7 +131,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"Project build failed. Please ensure that the command `uv build --sdist` completes successfully.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
raise SystemExit(1)
|
||||
|
||||
tarball_path = os.path.join(temp_build_dir, tarball_filename)
|
||||
with open(tarball_path, "rb") as file:
|
||||
@@ -134,6 +147,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
self._validate_response(publish_response)
|
||||
@@ -246,6 +260,55 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
def _print_tools_preview(self, tools_metadata: list[dict[str, Any]]) -> None:
|
||||
if not tools_metadata:
|
||||
console.print("[yellow]No tool metadata extracted.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(
|
||||
f"\n[bold]Tools to be published ({len(tools_metadata)}):[/bold]\n"
|
||||
)
|
||||
|
||||
for tool in tools_metadata:
|
||||
console.print(f" [bold cyan]{tool.get('name', 'Unknown')}[/bold cyan]")
|
||||
if tool.get("module"):
|
||||
console.print(f" Module: {tool.get('module')}")
|
||||
console.print(f" Name: {tool.get('humanized_name', 'N/A')}")
|
||||
console.print(
|
||||
f" Description: {tool.get('description', 'N/A')[:80]}{'...' if len(tool.get('description', '')) > 80 else ''}"
|
||||
)
|
||||
|
||||
init_params = tool.get("init_params_schema", {}).get("properties", {})
|
||||
if init_params:
|
||||
required = tool.get("init_params_schema", {}).get("required", [])
|
||||
console.print(" Init parameters:")
|
||||
for param_name, param_info in init_params.items():
|
||||
param_type = param_info.get("type", "any")
|
||||
is_required = param_name in required
|
||||
req_marker = "[red]*[/red]" if is_required else ""
|
||||
default = (
|
||||
f" = {param_info['default']}" if "default" in param_info else ""
|
||||
)
|
||||
console.print(
|
||||
f" - {param_name}: {param_type}{default} {req_marker}"
|
||||
)
|
||||
|
||||
env_vars = tool.get("env_vars", [])
|
||||
if env_vars:
|
||||
console.print(" Environment variables:")
|
||||
for env_var in env_vars:
|
||||
req_marker = "[red]*[/red]" if env_var.get("required") else ""
|
||||
default = (
|
||||
f" (default: {env_var['default']})"
|
||||
if env_var.get("default")
|
||||
else ""
|
||||
)
|
||||
console.print(
|
||||
f" - {env_var['name']}: {env_var.get('description', 'N/A')}{default} {req_marker}"
|
||||
)
|
||||
|
||||
console.print()
|
||||
|
||||
def _print_current_organization(self) -> None:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from functools import reduce
|
||||
from collections.abc import Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache, reduce
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
from inspect import getmro, isclass, isfunction, ismethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, cast, get_type_hints
|
||||
|
||||
import click
|
||||
@@ -544,43 +549,62 @@ def build_env_with_tool_repository_credentials(
|
||||
return env
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _load_module_from_file(
|
||||
init_file: Path, module_name: str | None = None
|
||||
) -> Generator[types.ModuleType | None, None, None]:
|
||||
"""
|
||||
Context manager for loading a module from file with automatic cleanup.
|
||||
|
||||
Yields the loaded module or None if loading fails.
|
||||
"""
|
||||
if module_name is None:
|
||||
module_name = (
|
||||
f"temp_module_{hashlib.sha256(str(init_file).encode()).hexdigest()[:8]}"
|
||||
)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, init_file)
|
||||
if not spec or not spec.loader:
|
||||
yield None
|
||||
return
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
yield module
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
|
||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load and validate tools from a given __init__.py file.
|
||||
"""
|
||||
spec = importlib.util.spec_from_file_location("temp_module", init_file)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
return []
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["temp_module"] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
with _load_module_from_file(init_file) as module:
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
if not hasattr(module, "__all__"):
|
||||
console.print(
|
||||
f"Warning: No __all__ defined in {init_file}",
|
||||
style="bold yellow",
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
}
|
||||
for name in module.__all__
|
||||
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
||||
]
|
||||
if not hasattr(module, "__all__"):
|
||||
console.print(
|
||||
f"Warning: No __all__ defined in {init_file}",
|
||||
style="bold yellow",
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
return [
|
||||
{"name": name}
|
||||
for name in module.__all__
|
||||
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
||||
]
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as e:
|
||||
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
finally:
|
||||
sys.modules.pop("temp_module", None)
|
||||
|
||||
|
||||
def _print_no_tools_warning() -> None:
|
||||
"""
|
||||
@@ -610,3 +634,242 @@ def _print_no_tools_warning() -> None:
|
||||
" # ... implementation\n"
|
||||
" return result\n"
|
||||
)
|
||||
|
||||
|
||||
def extract_tools_metadata(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract rich metadata from tool classes in the project.
|
||||
|
||||
Returns a list of tool metadata dictionaries containing:
|
||||
- name: Class name
|
||||
- humanized_name: From name field default
|
||||
- description: From description field default
|
||||
- run_params_schema: JSON Schema for _run() params (from args_schema)
|
||||
- init_params_schema: JSON Schema for __init__ params (filtered)
|
||||
- env_vars: List of environment variable dicts
|
||||
"""
|
||||
tools_metadata: list[dict[str, Any]] = []
|
||||
|
||||
for init_file in Path(dir_path).glob("**/__init__.py"):
|
||||
tools = _extract_tool_metadata_from_init(init_file)
|
||||
tools_metadata.extend(tools)
|
||||
|
||||
return tools_metadata
|
||||
|
||||
|
||||
def _extract_tool_metadata_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load module from init file and extract metadata from valid tool classes.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
try:
|
||||
with _load_module_from_file(init_file) as module:
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
exported_names = getattr(module, "__all__", None)
|
||||
if not exported_names:
|
||||
return []
|
||||
|
||||
tools_metadata = []
|
||||
for name in exported_names:
|
||||
obj = getattr(module, name, None)
|
||||
if obj is None or not (
|
||||
inspect.isclass(obj) and issubclass(obj, BaseTool)
|
||||
):
|
||||
continue
|
||||
if tool_info := _extract_single_tool_metadata(obj):
|
||||
tools_metadata.append(tool_info)
|
||||
|
||||
return tools_metadata
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract metadata from {init_file}: {e}[/yellow]"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _extract_single_tool_metadata(tool_class: type) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract metadata from a single tool class.
|
||||
"""
|
||||
try:
|
||||
core_schema = cast(Any, tool_class).__pydantic_core_schema__
|
||||
if not core_schema:
|
||||
return None
|
||||
|
||||
schema = _unwrap_schema(core_schema)
|
||||
fields = schema.get("schema", {}).get("fields", {})
|
||||
|
||||
try:
|
||||
file_path = inspect.getfile(tool_class)
|
||||
relative_path = Path(file_path).relative_to(Path.cwd())
|
||||
module_path = relative_path.with_suffix("")
|
||||
if module_path.parts[0] == "src":
|
||||
module_path = Path(*module_path.parts[1:])
|
||||
if module_path.name == "__init__":
|
||||
module_path = module_path.parent
|
||||
module = ".".join(module_path.parts)
|
||||
except (TypeError, ValueError):
|
||||
module = tool_class.__module__
|
||||
|
||||
return {
|
||||
"name": tool_class.__name__,
|
||||
"module": module,
|
||||
"humanized_name": _extract_field_default(
|
||||
fields.get("name"), fallback=tool_class.__name__
|
||||
),
|
||||
"description": str(
|
||||
_extract_field_default(fields.get("description"))
|
||||
).strip(),
|
||||
"run_params_schema": _extract_run_params_schema(fields.get("args_schema")),
|
||||
"init_params_schema": _extract_init_params_schema(tool_class),
|
||||
"env_vars": _extract_env_vars(fields.get("env_vars")),
|
||||
}
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _unwrap_schema(schema: Mapping[str, Any] | dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Unwrap nested schema structures to get to the actual schema definition.
|
||||
"""
|
||||
result: dict[str, Any] = dict(schema)
|
||||
while (
|
||||
result.get("type")
|
||||
in {"function-after", "function-before", "function-wrap", "default"}
|
||||
and "schema" in result
|
||||
):
|
||||
result = dict(result["schema"])
|
||||
if result.get("type") == "definitions" and "schema" in result:
|
||||
result = dict(result["schema"])
|
||||
return result
|
||||
|
||||
|
||||
def _extract_field_default(
|
||||
field: dict[str, Any] | None, fallback: str | list[Any] = ""
|
||||
) -> str | list[Any] | int:
|
||||
"""
|
||||
Extract the default value from a field schema.
|
||||
"""
|
||||
if not field:
|
||||
return fallback
|
||||
|
||||
schema = field.get("schema", {})
|
||||
default = schema.get("default")
|
||||
return default if isinstance(default, (list, str, int)) else fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_schema_generator() -> type:
|
||||
"""Get a SchemaGenerator that omits non-serializable defaults."""
|
||||
from pydantic.json_schema import GenerateJsonSchema
|
||||
from pydantic_core import PydanticOmit
|
||||
|
||||
class SchemaGenerator(GenerateJsonSchema):
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: Any, error_info: Any
|
||||
) -> dict[str, Any]:
|
||||
raise PydanticOmit
|
||||
|
||||
return SchemaGenerator
|
||||
|
||||
|
||||
def _extract_run_params_schema(
|
||||
args_schema_field: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's run parameters from args_schema field.
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
if not args_schema_field:
|
||||
return {}
|
||||
|
||||
args_schema_class = args_schema_field.get("schema", {}).get("default")
|
||||
if not (
|
||||
inspect.isclass(args_schema_class) and issubclass(args_schema_class, BaseModel)
|
||||
):
|
||||
return {}
|
||||
|
||||
try:
|
||||
return args_schema_class.model_json_schema(
|
||||
schema_generator=_get_schema_generator()
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
_IGNORED_INIT_PARAMS = frozenset(
|
||||
{
|
||||
"name",
|
||||
"description",
|
||||
"env_vars",
|
||||
"args_schema",
|
||||
"description_updated",
|
||||
"cache_function",
|
||||
"result_as_answer",
|
||||
"max_usage_count",
|
||||
"current_usage_count",
|
||||
"package_dependencies",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _extract_init_params_schema(tool_class: type) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's __init__ parameters, filtering out base fields.
|
||||
"""
|
||||
try:
|
||||
json_schema: dict[str, Any] = cast(Any, tool_class).model_json_schema(
|
||||
schema_generator=_get_schema_generator(), mode="serialization"
|
||||
)
|
||||
filtered_properties = {
|
||||
key: value
|
||||
for key, value in json_schema.get("properties", {}).items()
|
||||
if key not in _IGNORED_INIT_PARAMS
|
||||
}
|
||||
json_schema["properties"] = filtered_properties
|
||||
if "required" in json_schema:
|
||||
json_schema["required"] = [
|
||||
key for key in json_schema["required"] if key in filtered_properties
|
||||
]
|
||||
return json_schema
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_env_vars(env_vars_field: dict[str, Any] | None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract environment variable definitions from env_vars field.
|
||||
"""
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
if not env_vars_field:
|
||||
return []
|
||||
|
||||
schema = env_vars_field.get("schema", {})
|
||||
default = schema.get("default")
|
||||
if default is None:
|
||||
default_factory = schema.get("default_factory")
|
||||
if callable(default_factory):
|
||||
try:
|
||||
default = default_factory()
|
||||
except Exception:
|
||||
default = []
|
||||
|
||||
if not isinstance(default, list):
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"name": env_var.name,
|
||||
"description": env_var.description,
|
||||
"required": env_var.required,
|
||||
"default": env_var.default,
|
||||
}
|
||||
for env_var in default
|
||||
if isinstance(env_var, EnvVar)
|
||||
]
|
||||
|
||||
@@ -22,7 +22,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
Json,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
@@ -176,7 +175,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
@@ -210,13 +209,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
manager_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | None = Field(
|
||||
function_calling_llm: str | LLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
@@ -267,7 +266,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
planning_llm: str | BaseLLM | Any | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
@@ -288,7 +287,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"knowledge object."
|
||||
),
|
||||
)
|
||||
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
chat_llm: str | BaseLLM | Any | None = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -1800,7 +1799,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
eval_llm: str | BaseLLM,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations.
|
||||
|
||||
@@ -1966,37 +1966,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
"original_tool": original_tool,
|
||||
}
|
||||
|
||||
def _extract_tool_name(self, tool_call: Any) -> str:
|
||||
"""Extract tool name from various tool call formats."""
|
||||
if hasattr(tool_call, "function"):
|
||||
return sanitize_tool_name(tool_call.function.name)
|
||||
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
return sanitize_tool_name(tool_call.function_call.name)
|
||||
if hasattr(tool_call, "name"):
|
||||
return sanitize_tool_name(tool_call.name)
|
||||
if isinstance(tool_call, dict):
|
||||
func_info = tool_call.get("function", {})
|
||||
return sanitize_tool_name(
|
||||
func_info.get("name", "") or tool_call.get("name", "unknown")
|
||||
)
|
||||
return "unknown"
|
||||
|
||||
@router(execute_native_tool)
|
||||
def check_native_todo_completion(
|
||||
self,
|
||||
) -> Literal["todo_satisfied", "todo_not_satisfied"]:
|
||||
"""Check if the native tool execution satisfied the active todo.
|
||||
|
||||
Similar to check_todo_completion but for native tool execution path.
|
||||
"""
|
||||
current_todo = self.state.todos.current_todo
|
||||
|
||||
if not current_todo:
|
||||
return "todo_not_satisfied"
|
||||
|
||||
# For native tools, any tool execution satisfies the todo
|
||||
return "todo_satisfied"
|
||||
|
||||
@listen("initialized")
|
||||
def continue_iteration(self) -> Literal["check_iteration"]:
|
||||
"""Bridge listener that connects iteration loop back to iteration check."""
|
||||
|
||||
@@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class BaseKnowledgeStorage(ABC):
|
||||
class BaseKnowledgeStorage(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""Abstract base class for knowledge storage implementations."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -3,6 +3,9 @@ import traceback
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
|
||||
from pydantic import Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
collection_name: str | None = None
|
||||
embedder: (
|
||||
ProviderSpec
|
||||
| BaseEmbeddingsProvider[Any]
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
| None
|
||||
) = Field(default=None, exclude=True)
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> Self:
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
||||
if self.embedder:
|
||||
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
return self
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
|
||||
@@ -22,7 +22,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -204,7 +203,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
llm: str | BaseLLM | Any | None = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: list[BaseTool] = Field(
|
||||
|
||||
@@ -20,8 +20,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -37,7 +36,12 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context
|
||||
from crewai.llms.base_llm import (
|
||||
BaseLLM,
|
||||
JsonResponseFormat,
|
||||
get_current_call_id,
|
||||
llm_call_context,
|
||||
)
|
||||
from crewai.llms.constants import (
|
||||
ANTHROPIC_MODELS,
|
||||
AZURE_MODELS,
|
||||
@@ -63,8 +67,6 @@ except ImportError:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -342,6 +344,27 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: float | None = None
|
||||
timeout: float | int | None = None
|
||||
top_p: float | None = None
|
||||
n: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | float | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
logit_bias: dict[int, float] | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
seed: int | None = None
|
||||
logprobs: int | None = None
|
||||
top_logprobs: int | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
callbacks: list[Any] | None = None
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
|
||||
stream: bool = False
|
||||
interceptor: Any = None
|
||||
thinking: Any = None
|
||||
context_window_size: int = 0
|
||||
is_anthropic: bool = False
|
||||
|
||||
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
|
||||
"""Factory method that routes to native SDK or falls back to LiteLLM.
|
||||
@@ -436,10 +459,7 @@ class LLM(BaseLLM):
|
||||
logger.error(error_msg)
|
||||
raise ImportError(error_msg) from None
|
||||
|
||||
instance = object.__new__(cls)
|
||||
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
return object.__new__(cls)
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
@@ -624,89 +644,23 @@ class LLM(BaseLLM):
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: float | int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize LLM instance.
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_llm_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
model = data.get("model", "")
|
||||
data["is_anthropic"] = cls._is_anthropic_model(model)
|
||||
return data
|
||||
|
||||
Note: This __init__ method is only called for fallback instances.
|
||||
Native provider instances handle their own initialization in their respective classes.
|
||||
"""
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.prefer_upload = prefer_upload
|
||||
self.additional_params = {
|
||||
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
|
||||
}
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.interceptor = interceptor
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a list[str]
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
self.set_callbacks(callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
@model_validator(mode="after")
|
||||
def _init_litellm(self) -> LLM:
|
||||
self.is_litellm = True
|
||||
if LITELLM_AVAILABLE:
|
||||
litellm.drop_params = True
|
||||
self.set_callbacks(self.callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _is_anthropic_model(model: str) -> bool:
|
||||
@@ -753,7 +707,7 @@ class LLM(BaseLLM):
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop or None,
|
||||
"stop": (self.stop or None) if self.supports_stop_words() else None,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
@@ -1825,9 +1779,11 @@ class LLM(BaseLLM):
|
||||
# whether to summarize the content or abort based on the respect_context_window flag
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
error_str = str(e)
|
||||
unsupported_stop = "'stop'" in error_str and (
|
||||
"Unsupported parameter" in error_str
|
||||
or "does not support parameters" in error_str
|
||||
)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
@@ -1961,9 +1917,11 @@ class LLM(BaseLLM):
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
error_str = str(e)
|
||||
unsupported_stop = "'stop'" in error_str and (
|
||||
"Unsupported parameter" in error_str
|
||||
or "does not support parameters" in error_str
|
||||
)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
@@ -2263,6 +2221,10 @@ class LLM(BaseLLM):
|
||||
Note: This method is only used by the litellm fallback path.
|
||||
Native providers override this method with their own implementation.
|
||||
"""
|
||||
model_lower = self.model.lower() if self.model else ""
|
||||
if "gpt-5" in model_lower:
|
||||
return False
|
||||
|
||||
if not LITELLM_AVAILABLE or get_supported_openai_params is None:
|
||||
# When litellm is not available, assume stop words are supported
|
||||
return True
|
||||
@@ -2434,7 +2396,7 @@ class LLM(BaseLLM):
|
||||
**filtered_params,
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM:
|
||||
"""Create a deep copy of the LLM instance."""
|
||||
import copy
|
||||
|
||||
|
||||
@@ -14,10 +14,18 @@ from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
@@ -51,6 +59,12 @@ if TYPE_CHECKING:
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class JsonResponseFormat(TypedDict):
|
||||
"""Response format requesting raw JSON output (e.g. ``{"type": "json_object"}``)."""
|
||||
|
||||
type: Literal["json_object"]
|
||||
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
|
||||
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
|
||||
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
|
||||
@@ -82,7 +96,7 @@ def get_current_call_id() -> str:
|
||||
return call_id
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
"""Abstract base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
@@ -101,56 +115,100 @@ class BaseLLM(ABC):
|
||||
additional_params: Additional provider-specific parameters.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
model: str
|
||||
temperature: float | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
provider: str = Field(default="openai")
|
||||
prefer_upload: bool = False
|
||||
is_litellm: bool = False
|
||||
stop: list[str] = Field(
|
||||
default_factory=list,
|
||||
validation_alias=AliasChoices("stop", "stop_sequences"),
|
||||
)
|
||||
additional_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
provider: str | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in ("stop", "stop_sequences"):
|
||||
if value is None:
|
||||
value = []
|
||||
elif isinstance(value, str):
|
||||
value = [value]
|
||||
elif not isinstance(value, list):
|
||||
value = list(value)
|
||||
name = "stop"
|
||||
try:
|
||||
super().__setattr__(name, value)
|
||||
except ValueError:
|
||||
if name in self.model_fields:
|
||||
raise # Re-raise validation errors on declared fields
|
||||
# Fallback for attributes not declared as fields (e.g. mock patching)
|
||||
object.__setattr__(self, name, value)
|
||||
except AttributeError:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
Args:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
prefer_upload: Whether to prefer file upload over inline base64.
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
def __delattr__(self, name: str) -> None:
|
||||
try:
|
||||
super().__delattr__(name)
|
||||
except AttributeError:
|
||||
object.__delattr__(self, name)
|
||||
|
||||
@property
|
||||
def stop_sequences(self) -> list[str]:
|
||||
"""Alias for ``stop`` — kept for backward compatibility with provider APIs.
|
||||
|
||||
Writes are handled by ``__setattr__``, which normalizes and redirects
|
||||
``stop_sequences`` assignments to the ``stop`` field.
|
||||
"""
|
||||
if not model:
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
return self.stop
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.prefer_upload = prefer_upload
|
||||
# Store additional parameters for provider-specific use
|
||||
self.additional_params = kwargs
|
||||
self._provider = provider or "openai"
|
||||
|
||||
stop = kwargs.pop("stop", None)
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
elif isinstance(stop, list):
|
||||
self.stop = stop
|
||||
else:
|
||||
self.stop = []
|
||||
|
||||
self._token_usage = {
|
||||
_token_usage: dict[str, int] = PrivateAttr(
|
||||
default_factory=lambda: {
|
||||
"total_tokens": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"successful_requests": 0,
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_init_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if not data.get("model"):
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
# Normalize stop: accept str, list, or None; also accept stop_sequences alias
|
||||
stop_seqs = data.pop("stop_sequences", None)
|
||||
stop = stop_seqs if stop_seqs is not None else data.get("stop")
|
||||
if stop is None:
|
||||
data["stop"] = []
|
||||
elif isinstance(stop, str):
|
||||
data["stop"] = [stop]
|
||||
elif isinstance(stop, list):
|
||||
data["stop"] = stop
|
||||
else:
|
||||
data["stop"] = list(stop)
|
||||
|
||||
# Default provider
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
|
||||
# Collect unknown kwargs into additional_params
|
||||
known_fields = set(cls.model_fields.keys())
|
||||
extras = {k: v for k, v in data.items() if k not in known_fields}
|
||||
for k in extras:
|
||||
data.pop(k)
|
||||
existing = data.get("additional_params") or {}
|
||||
existing.update(extras)
|
||||
data["additional_params"] = existing
|
||||
|
||||
return data
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.
|
||||
@@ -174,16 +232,6 @@ class BaseLLM(ABC):
|
||||
|
||||
return config
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider of the LLM."""
|
||||
return self._provider
|
||||
|
||||
@provider.setter
|
||||
def provider(self, value: str) -> None:
|
||||
"""Set the provider of the LLM."""
|
||||
self._provider = value
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
|
||||
@@ -3,12 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
|
||||
from typing import Any, Final, Literal, TypeGuard, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -17,9 +18,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import (
|
||||
@@ -150,60 +148,47 @@ class AnthropicCompletion(BaseLLM):
|
||||
offering native tool use, streaming support, and proper message formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int = 4096, # Required for Anthropic
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
model: str = "claude-3-5-sonnet-20241022"
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
max_tokens: int = 4096
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
client_params: dict[str, Any] | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
thinking: AnthropicThinkingConfig | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
tool_search: AnthropicToolSearchConfig | None = None
|
||||
is_claude_3: bool = False
|
||||
supports_tools: bool = True
|
||||
|
||||
Args:
|
||||
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
|
||||
base_url: Custom base URL for Anthropic API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens in response (required for Anthropic)
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_previous_thinking_blocks: list[Any] = PrivateAttr(default_factory=list)
|
||||
|
||||
# Client params
|
||||
self.interceptor = interceptor
|
||||
self.client_params = client_params
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_anthropic_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
# Anthropic uses stop_sequences; normalize from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
data["is_claude_3"] = "claude-3" in data.get("model", "").lower()
|
||||
# Normalize tool_search
|
||||
ts = data.get("tool_search")
|
||||
if ts is True:
|
||||
data["tool_search"] = AnthropicToolSearchConfig()
|
||||
elif ts is not None and not isinstance(ts, AnthropicToolSearchConfig):
|
||||
data["tool_search"] = None
|
||||
return data
|
||||
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AnthropicCompletion:
|
||||
self._client = Anthropic(**self._get_client_params())
|
||||
|
||||
async_client_params = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -211,51 +196,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_params["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncAnthropic(**async_client_params)
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
self.tool_search: AnthropicToolSearchConfig | None
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Anthropic API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
self._async_client = AsyncAnthropic(**async_client_params)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Anthropic-specific fields."""
|
||||
@@ -751,11 +693,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
elif self.thinking and self.previous_thinking_blocks:
|
||||
elif self.thinking and self._previous_thinking_blocks:
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
*self.previous_thinking_blocks,
|
||||
*self._previous_thinking_blocks,
|
||||
{"type": "text", "text": content if content else ""},
|
||||
],
|
||||
)
|
||||
@@ -809,7 +751,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -843,11 +785,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self.client.beta.messages.create(
|
||||
response = self._client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self.client.messages.create(**params)
|
||||
response = self._client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -928,7 +870,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
self._emit_call_completed_event(
|
||||
@@ -952,7 +894,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -991,9 +933,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
if betas
|
||||
else self.client.messages.stream(**stream_params)
|
||||
else self._client.messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1072,7 +1014,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
@@ -1269,7 +1211,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self.client.messages.create(**follow_up_params)
|
||||
final_response: Message = self._client.messages.create(**follow_up_params)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
@@ -1288,7 +1230,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
@@ -1330,7 +1272,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -1364,11 +1306,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self.async_client.beta.messages.create(
|
||||
response = await self._async_client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self.async_client.messages.create(**params)
|
||||
response = await self._async_client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -1461,7 +1403,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -1498,11 +1440,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self.async_client.beta.messages.stream(
|
||||
self._async_client.beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self.async_client.messages.stream(**stream_params)
|
||||
else self._async_client.messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1664,7 +1606,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self.async_client.messages.create(
|
||||
final_response: Message = await self._async_client.messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
@@ -1786,8 +1728,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from crewai_files.uploaders.anthropic import AnthropicFileUploader
|
||||
|
||||
return AnthropicFileUploader(
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -3,11 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -16,10 +18,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import (
|
||||
ChatCompletionsClient,
|
||||
@@ -76,109 +74,84 @@ class AzureCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Azure authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
endpoint: str | None = None
|
||||
api_version: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stream: bool = False
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_openai_model: bool = False
|
||||
is_azure_openai_endpoint: bool = False
|
||||
|
||||
Args:
|
||||
model: Azure deployment name or model name
|
||||
api_key: Azure API key (defaults to AZURE_API_KEY env var)
|
||||
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
|
||||
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
max_tokens: Maximum tokens in response
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
Only works with OpenAI models deployed on Azure.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_azure_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop or [], **kwargs
|
||||
)
|
||||
|
||||
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
||||
self.endpoint = (
|
||||
endpoint
|
||||
# Resolve env vars
|
||||
data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY")
|
||||
data["endpoint"] = (
|
||||
data.get("endpoint")
|
||||
or os.getenv("AZURE_ENDPOINT")
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
data["api_version"] = (
|
||||
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
if not data["api_key"]:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
if not self.endpoint:
|
||||
if not data["endpoint"]:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||
model = data.get("model", "")
|
||||
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
|
||||
data["endpoint"], model
|
||||
)
|
||||
data["is_openai_model"] = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
parsed = urlparse(data["endpoint"])
|
||||
hostname = parsed.hostname or ""
|
||||
data["is_azure_openai_endpoint"] = (
|
||||
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
|
||||
) and "/openai/deployments/" in data["endpoint"]
|
||||
return data
|
||||
|
||||
# Build client kwargs
|
||||
client_kwargs = {
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AzureCompletion:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key is required.")
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
|
||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
|
||||
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
self.is_azure_openai_endpoint = (
|
||||
"openai.azure.com" in self.endpoint
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
self._client = ChatCompletionsClient(**client_kwargs)
|
||||
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Azure-specific fields."""
|
||||
@@ -215,7 +188,11 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Validated and potentially corrected endpoint URL
|
||||
"""
|
||||
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
|
||||
ep_host = urlparse(endpoint).hostname or ""
|
||||
is_azure_openai = ep_host == "openai.azure.com" or ep_host.endswith(
|
||||
".openai.azure.com"
|
||||
)
|
||||
if is_azure_openai and "/openai/deployments/" not in endpoint:
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
if not endpoint.endswith("/openai/deployments"):
|
||||
@@ -731,7 +708,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
response: ChatCompletions = self._client.complete(**params)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -926,7 +903,7 @@ class AzureCompletion(BaseLLM):
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
for update in self.client.complete(**params): # type: ignore[arg-type]
|
||||
for update in self._client.complete(**params):
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.usage:
|
||||
usage = update.usage
|
||||
@@ -967,7 +944,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
response: ChatCompletions = await self._async_client.complete(**params)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -993,8 +970,8 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
|
||||
async for update in stream: # type: ignore[union-attr]
|
||||
stream = await self._async_client.complete(**params)
|
||||
async for update in stream:
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if hasattr(update, "usage") and update.usage:
|
||||
usage = update.usage
|
||||
@@ -1110,8 +1087,8 @@ class AzureCompletion(BaseLLM):
|
||||
This ensures proper cleanup of the underlying aiohttp session
|
||||
to avoid unclosed connector warnings.
|
||||
"""
|
||||
if hasattr(self.async_client, "close"):
|
||||
await self.async_client.close()
|
||||
if hasattr(self._async_client, "close"):
|
||||
await self._async_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Async context manager entry."""
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from typing_extensions import Required
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
ToolTypeDef,
|
||||
)
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
@@ -228,129 +228,97 @@ class BedrockCompletion(BaseLLM):
|
||||
- Model-specific conversation format handling (e.g., Cohere requirements)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
aws_session_token: str | None = None,
|
||||
region_name: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
stop_sequences: Sequence[str] | None = None,
|
||||
stream: bool = False,
|
||||
guardrail_config: dict[str, Any] | None = None,
|
||||
additional_model_request_fields: dict[str, Any] | None = None,
|
||||
additional_model_response_field_paths: list[str] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize AWS Bedrock completion client.
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
aws_session_token: str | None = None
|
||||
region_name: str | None = None
|
||||
max_tokens: int | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
stream: bool = False
|
||||
guardrail_config: dict[str, Any] | None = None
|
||||
additional_model_request_fields: dict[str, Any] | None = None
|
||||
additional_model_response_field_paths: list[str] | None = None
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_claude_model: bool = False
|
||||
supports_tools: bool = True
|
||||
supports_streaming: bool = True
|
||||
model_id: str = ""
|
||||
|
||||
Args:
|
||||
model: The Bedrock model ID to use
|
||||
aws_access_key_id: AWS access key (defaults to environment variable)
|
||||
aws_secret_access_key: AWS secret key (defaults to environment variable)
|
||||
aws_session_token: AWS session token for temporary credentials
|
||||
region_name: AWS region name
|
||||
temperature: Sampling temperature for response generation
|
||||
max_tokens: Maximum tokens to generate
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter (Claude models only)
|
||||
stop_sequences: List of sequences that stop generation
|
||||
stream: Whether to use streaming responses
|
||||
guardrail_config: Guardrail configuration for content filtering
|
||||
additional_model_request_fields: Model-specific request parameters
|
||||
additional_model_response_field_paths: Custom response field paths
|
||||
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_exit_stack: Any = PrivateAttr(default=None)
|
||||
_async_client_initialized: bool = PrivateAttr(default=False)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_bedrock_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Extract provider from kwargs to avoid duplicate argument
|
||||
kwargs.pop("provider", None)
|
||||
# Force provider to bedrock
|
||||
data.pop("provider", None)
|
||||
data["provider"] = "bedrock"
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
stop=stop_sequences or [],
|
||||
provider="bedrock",
|
||||
**kwargs,
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
elif isinstance(seqs, Sequence) and not isinstance(seqs, list):
|
||||
seqs = list(seqs)
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv(
|
||||
"AWS_ACCESS_KEY_ID"
|
||||
)
|
||||
|
||||
# Configure client with timeouts and retries following AWS best practices
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={
|
||||
"max_attempts": 3,
|
||||
"mode": "adaptive",
|
||||
},
|
||||
tcp_keepalive=True,
|
||||
data["aws_secret_access_key"] = data.get("aws_secret_access_key") or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
|
||||
self.region_name = (
|
||||
region_name
|
||||
data["aws_session_token"] = data.get("aws_session_token") or os.getenv(
|
||||
"AWS_SESSION_TOKEN"
|
||||
)
|
||||
data["region_name"] = (
|
||||
data.get("region_name")
|
||||
or os.getenv("AWS_DEFAULT_REGION")
|
||||
or os.getenv("AWS_REGION_NAME")
|
||||
or "us-east-1"
|
||||
)
|
||||
|
||||
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
model = data.get("model", "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
data["is_claude_model"] = "claude" in model.lower()
|
||||
data["model_id"] = model
|
||||
return data
|
||||
|
||||
# Initialize Bedrock client with proper configuration
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> BedrockCompletion:
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={"max_attempts": 3, "mode": "adaptive"},
|
||||
tcp_keepalive=True,
|
||||
)
|
||||
session = Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
aws_session_token=self.aws_session_token,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
|
||||
self.client = session.client("bedrock-runtime", config=config)
|
||||
|
||||
self._client = session.client("bedrock-runtime", config=config)
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
self._async_client_initialized = False
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences
|
||||
self.response_format = response_format
|
||||
|
||||
# Store advanced features (optional)
|
||||
self.guardrail_config = guardrail_config
|
||||
self.additional_model_request_fields = additional_model_request_fields
|
||||
self.additional_model_response_field_paths = (
|
||||
additional_model_response_field_paths
|
||||
)
|
||||
|
||||
# Model-specific settings
|
||||
self.is_claude_model = "claude" in model.lower()
|
||||
self.supports_tools = True # Converse API supports tools for most models
|
||||
self.supports_streaming = True
|
||||
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Bedrock-specific fields."""
|
||||
config = super().to_config_dict()
|
||||
# NOTE: AWS credentials (access_key, secret_key, session_token) are
|
||||
# intentionally excluded — they must come from env on resume.
|
||||
if self.region_name and self.region_name != "us-east-1":
|
||||
config["region_name"] = self.region_name
|
||||
if self.max_tokens is not None:
|
||||
@@ -363,30 +331,6 @@ class BedrockCompletion(BaseLLM):
|
||||
config["guardrail_config"] = self.guardrail_config
|
||||
return config
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return [] if self.stop_sequences is None else list(self.stop_sequences)
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: Sequence[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Bedrock API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a Sequence, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, Sequence):
|
||||
self.stop_sequences = list(value)
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
@@ -710,7 +654,7 @@ class BedrockCompletion(BaseLLM):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
# Call Bedrock Converse API with proper error handling
|
||||
response = self.client.converse(
|
||||
response = self._client.converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
@@ -994,13 +938,13 @@ class BedrockCompletion(BaseLLM):
|
||||
accumulated_tool_input = ""
|
||||
|
||||
try:
|
||||
response = self.client.converse_stream(
|
||||
response = self._client.converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body, # type: ignore[arg-type]
|
||||
**body,
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
|
||||
@@ -5,12 +5,13 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -19,10 +20,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -44,137 +41,84 @@ class GeminiCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
api_key: str | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
use_vertexai: bool | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
thinking_config: types.ThinkingConfig | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
model: str = "gemini-2.0-flash-001"
|
||||
project: str | None = None
|
||||
location: str | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
max_output_tokens: int | None = None
|
||||
stream: bool = False
|
||||
safety_settings: dict[str, Any] = Field(default_factory=dict)
|
||||
client_params: dict[str, Any] = Field(default_factory=dict)
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
use_vertexai: bool = False
|
||||
response_format: type[BaseModel] | None = None
|
||||
thinking_config: Any = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
supports_tools: bool = False
|
||||
is_gemini_2_0: bool = False
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
api_key: Google API key for Gemini API authentication.
|
||||
Defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var.
|
||||
NOTE: Cannot be used with Vertex AI (project parameter). Use Gemini API instead.
|
||||
project: Google Cloud project ID for Vertex AI with ADC authentication.
|
||||
Requires Application Default Credentials (gcloud auth application-default login).
|
||||
NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC.
|
||||
If both api_key and project are set, api_key takes precedence.
|
||||
location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1')
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||
Supports parameters like http_options, credentials, debug_config, etc.
|
||||
interceptor: HTTP interceptor (not yet supported for Gemini).
|
||||
use_vertexai: Whether to use Vertex AI instead of Gemini API.
|
||||
- True: Use Vertex AI (with ADC or Express mode with API key)
|
||||
- False: Use Gemini API (explicitly override env var)
|
||||
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
||||
When using Vertex AI with API key (Express mode), http_options with
|
||||
api_version="v1" is automatically configured.
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
thinking_config: ThinkingConfig for thinking models (gemini-2.5+, gemini-3+).
|
||||
Controls thought output via include_thoughts, thinking_budget,
|
||||
and thinking_level. When None, thinking models automatically
|
||||
get include_thoughts=True so thought content is surfaced.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_gemini_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = (
|
||||
data.get("api_key")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
data["project"] = data.get("project") or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
data["location"] = (
|
||||
data.get("location") or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
)
|
||||
|
||||
# Store client params for later use
|
||||
self.client_params = client_params or {}
|
||||
|
||||
# Get API configuration with environment variable fallbacks
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
|
||||
if use_vertexai is None:
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
|
||||
self.client = self._initialize_client(use_vertexai)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.tools: list[dict[str, Any]] | None = None
|
||||
self.response_format = response_format
|
||||
use_vx = data.get("use_vertexai")
|
||||
if use_vx is None:
|
||||
use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
data["use_vertexai"] = use_vx
|
||||
|
||||
# Model-specific settings
|
||||
model = data.get("model", "gemini-2.0-flash-001")
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
self.supports_tools = bool(
|
||||
data["supports_tools"] = bool(
|
||||
version_match and float(version_match.group(1)) >= 1.5
|
||||
)
|
||||
self.is_gemini_2_0 = bool(
|
||||
data["is_gemini_2_0"] = bool(
|
||||
version_match and float(version_match.group(1)) >= 2.0
|
||||
)
|
||||
|
||||
self.thinking_config = thinking_config
|
||||
# Auto-enable thinking for gemini-2.5+
|
||||
if (
|
||||
self.thinking_config is None
|
||||
data.get("thinking_config") is None
|
||||
and version_match
|
||||
and float(version_match.group(1)) >= 2.5
|
||||
):
|
||||
self.thinking_config = types.ThinkingConfig(include_thoughts=True)
|
||||
data["thinking_config"] = types.ThinkingConfig(include_thoughts=True)
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
return data
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Gemini API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> GeminiCompletion:
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Gemini/Vertex-specific fields."""
|
||||
@@ -283,8 +227,8 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
if (
|
||||
hasattr(self, "client")
|
||||
and hasattr(self.client, "vertexai")
|
||||
and self.client.vertexai
|
||||
and hasattr(self._client, "vertexai")
|
||||
and self._client.vertexai
|
||||
):
|
||||
# Vertex AI configuration
|
||||
params.update(
|
||||
@@ -1152,7 +1096,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self.client.models.generate_content(
|
||||
response = self._client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1192,7 +1136,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
for chunk in self._client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1230,7 +1174,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self.client.aio.models.generate_content(
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1270,7 +1214,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
stream = await self._client.aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1474,6 +1418,6 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
from crewai_files.uploaders.gemini import GeminiFileUploader
|
||||
|
||||
return GeminiFileUploader(client=self.client)
|
||||
return GeminiFileUploader(client=self._client)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -14,10 +14,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -29,7 +30,6 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -183,77 +183,69 @@ class OpenAICompletion(BaseLLM):
|
||||
"computer_use": "computer_use_preview",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
default_query: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
seed: int | None = None,
|
||||
stream: bool = False,
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
provider: str | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
api: Literal["completions", "responses"] = "completions",
|
||||
instructions: str | None = None,
|
||||
store: bool | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
builtin_tools: list[str] | None = None,
|
||||
parse_tool_outputs: bool = False,
|
||||
auto_chain: bool = False,
|
||||
auto_chain_reasoning: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI completion client."""
|
||||
model: str = "gpt-4o"
|
||||
organization: str | None = None
|
||||
project: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
default_headers: dict[str, str] | None = None
|
||||
default_query: dict[str, Any] | None = None
|
||||
client_params: dict[str, Any] | None = None
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
seed: int | None = None
|
||||
stream: bool = False
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
logprobs: bool | None = None
|
||||
top_logprobs: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
api: Literal["completions", "responses"] = "completions"
|
||||
instructions: str | None = None
|
||||
store: bool | None = None
|
||||
previous_response_id: str | None = None
|
||||
include: list[str] | None = None
|
||||
builtin_tools: list[str] | None = None
|
||||
parse_tool_outputs: bool = False
|
||||
auto_chain: bool = False
|
||||
auto_chain_reasoning: bool = False
|
||||
api_base: str | None = None
|
||||
is_o1_model: bool = False
|
||||
is_gpt4_model: bool = False
|
||||
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_last_response_id: str | None = PrivateAttr(default=None)
|
||||
_last_reasoning_items: list[Any] | None = PrivateAttr(default=None)
|
||||
|
||||
self.interceptor = interceptor
|
||||
# Client configuration attributes
|
||||
self.organization = organization
|
||||
self.project = project
|
||||
self.max_retries = max_retries
|
||||
self.default_headers = default_headers
|
||||
self.default_query = default_query
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.api_base = kwargs.pop("api_base", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
provider=provider,
|
||||
**kwargs,
|
||||
)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_openai_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY")
|
||||
# Extract api_base from kwargs if present
|
||||
if "api_base" not in data:
|
||||
data["api_base"] = None
|
||||
model = data.get("model", "gpt-4o")
|
||||
data["is_o1_model"] = "o1" in model.lower()
|
||||
data["is_gpt4_model"] = "gpt-4" in model.lower()
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> OpenAICompletion:
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
self._client = OpenAI(**client_config)
|
||||
|
||||
async_client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -261,35 +253,8 @@ class OpenAICompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_config["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncOpenAI(**async_client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.seed = seed
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
# API selection and Responses API parameters
|
||||
self.api = api
|
||||
self.instructions = instructions
|
||||
self.store = store
|
||||
self.previous_response_id = previous_response_id
|
||||
self.include = include
|
||||
self.builtin_tools = builtin_tools
|
||||
self.parse_tool_outputs = parse_tool_outputs
|
||||
self.auto_chain = auto_chain
|
||||
self.auto_chain_reasoning = auto_chain_reasoning
|
||||
self._last_response_id: str | None = None
|
||||
self._last_reasoning_items: list[Any] | None = None
|
||||
self._async_client = AsyncOpenAI(**async_client_config)
|
||||
return self
|
||||
|
||||
@property
|
||||
def last_response_id(self) -> str | None:
|
||||
@@ -818,7 +783,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = self.client.responses.create(**params)
|
||||
response: Response = self._client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -950,7 +915,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle async non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = await self.async_client.responses.create(**params)
|
||||
response: Response = await self._async_client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -1081,7 +1046,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = self.client.responses.create(**params)
|
||||
stream = self._client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
for event in stream:
|
||||
@@ -1205,7 +1170,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = await self.async_client.responses.create(**params)
|
||||
stream = await self._async_client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
async for event in stream:
|
||||
@@ -1595,7 +1560,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
parsed_response = self._client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1618,7 +1583,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
||||
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
@@ -1837,7 +1802,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self.client.beta.chat.completions.stream(
|
||||
with self._client.beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
@@ -1873,7 +1838,7 @@ class OpenAICompletion(BaseLLM):
|
||||
return ""
|
||||
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
self._client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -1970,7 +1935,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = await self.async_client.beta.chat.completions.parse(
|
||||
parsed_response = await self._async_client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1993,7 +1958,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
response: ChatCompletion = await self._async_client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
@@ -2111,7 +2076,7 @@ class OpenAICompletion(BaseLLM):
|
||||
if response_model:
|
||||
completion_stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -2164,7 +2129,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
@@ -2245,6 +2210,9 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
model_lower = self.model.lower() if self.model else ""
|
||||
if "gpt-5" in model_lower:
|
||||
return False
|
||||
return not self.is_o1_model
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
@@ -2353,8 +2321,8 @@ class OpenAICompletion(BaseLLM):
|
||||
from crewai_files.uploaders.openai import OpenAIFileUploader
|
||||
|
||||
return OpenAIFileUploader(
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -16,6 +16,8 @@ from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
|
||||
@@ -140,31 +142,13 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI-compatible completion client.
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _resolve_provider_config(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
Args:
|
||||
model: The model identifier.
|
||||
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
|
||||
api_key: Optional API key override. If not provided, uses the
|
||||
provider's configured environment variable.
|
||||
base_url: Optional base URL override. If not provided, uses the
|
||||
provider's configured default or environment variable.
|
||||
default_headers: Optional headers to merge with provider defaults.
|
||||
**kwargs: Additional arguments passed to OpenAICompletion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported or required API key
|
||||
is missing.
|
||||
"""
|
||||
provider = data.get("provider", "")
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
|
||||
if config is None:
|
||||
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
|
||||
@@ -173,21 +157,15 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
|
||||
resolved_api_key = self._resolve_api_key(api_key, config, provider)
|
||||
resolved_base_url = self._resolve_base_url(base_url, config, provider)
|
||||
resolved_headers = self._resolve_headers(default_headers, config)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
provider=provider,
|
||||
api_key=resolved_api_key,
|
||||
base_url=resolved_base_url,
|
||||
default_headers=resolved_headers,
|
||||
**kwargs,
|
||||
data["api_key"] = cls._resolve_api_key(data.get("api_key"), config, provider)
|
||||
data["base_url"] = cls._resolve_base_url(data.get("base_url"), config, provider)
|
||||
data["default_headers"] = cls._resolve_headers(
|
||||
data.get("default_headers"), config
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
api_key: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -220,8 +198,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return config.default_api_key
|
||||
|
||||
@staticmethod
|
||||
def _resolve_base_url(
|
||||
self,
|
||||
base_url: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -249,8 +227,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _resolve_headers(
|
||||
self,
|
||||
headers: dict[str, str] | None,
|
||||
config: ProviderConfig,
|
||||
) -> dict[str, str] | None:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Third-party LLM implementations for crewAI."""
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.box import HEAVY_EDGE
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
@@ -39,9 +39,9 @@ class CrewEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
crew: Crew,
|
||||
eval_llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
eval_llm: BaseLLM | str | None = None,
|
||||
openai_model_name: str | None = None,
|
||||
llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
llm: BaseLLM | str | None = None,
|
||||
) -> None:
|
||||
self.crew = crew
|
||||
self.llm = eval_llm
|
||||
|
||||
@@ -1692,9 +1692,27 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
) as mock_knowledge_storage:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
class _StubStorage(BaseKnowledgeStorage):
|
||||
def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
def save(self, documents):
|
||||
pass
|
||||
|
||||
async def asave(self, documents):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def areset(self):
|
||||
pass
|
||||
|
||||
mock_knowledge_storage.return_value = _StubStorage()
|
||||
agent.knowledge_storage = _StubStorage()
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
|
||||
@@ -879,30 +879,6 @@ class TestNativeToolExecution:
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0]["tool_call_id"] == "call_1"
|
||||
|
||||
def test_check_native_todo_completion_requires_current_todo(
|
||||
self, mock_dependencies
|
||||
):
|
||||
from crewai.utilities.planning_types import TodoList
|
||||
|
||||
executor = AgentExecutor(**mock_dependencies)
|
||||
|
||||
# No current todo → not satisfied
|
||||
executor.state.todos = TodoList(items=[])
|
||||
assert executor.check_native_todo_completion() == "todo_not_satisfied"
|
||||
|
||||
# With a current todo that has tool_to_use → satisfied
|
||||
running = TodoItem(
|
||||
step_number=1,
|
||||
description="Use the expected tool",
|
||||
tool_to_use="expected_tool",
|
||||
status="running",
|
||||
)
|
||||
executor.state.todos = TodoList(items=[running])
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
# With a current todo without tool_to_use → still satisfied
|
||||
running.tool_to_use = None
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
|
||||
class TestPlannerObserver:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -68,13 +72,9 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -87,7 +87,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6158'
|
||||
- '6065'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -109,26 +109,113 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/responses
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXQCgT3p3ViImkiqDiZGqLREQtp\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747248,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_9ZqMavn3J1fBnQEaqpYol0Bd\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Tokyo\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
string: "{\n \"id\": \"resp_0d68149bcc0d14810069caf464a4b48197bd9f098abb2f6303\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908516,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908517,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0d68149bcc0d14810069caf46568088197a33be67f16a1fa09\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Tokyo\\\"}\",\n \"call_id\": \"call_74rwmYse0DE4JFaFGyAFx9bu\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 0\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -137,7 +224,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 10 Feb 2026 18:14:08 GMT
|
||||
- Mon, 30 Mar 2026 22:08:37 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -146,8 +233,6 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -155,15 +240,13 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '484'
|
||||
- '1085'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
@@ -182,8 +265,12 @@ interactions:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -250,13 +337,9 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -269,7 +352,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6158'
|
||||
- '6065'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
@@ -293,26 +376,113 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/responses
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXR8k9vk8TlGvGXlrQSI7iNeAN1\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747249,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_6PeUBlRPG8JcV2lspmLjJbnn\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Paris\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
string: "{\n \"id\": \"resp_0525bf798202137e0069caf465ee3c8196aa7c83da1c369eb7\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908517,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908518,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0525bf798202137e0069caf46666588196a2ec20dc515a6a91\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Paris\\\"}\",\n \"call_id\": \"call_LJAGuYYZPjNxSgg0TUgGpT44\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 1152\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -321,7 +491,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 10 Feb 2026 18:14:09 GMT
|
||||
- Mon, 30 Mar 2026 22:08:38 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -330,8 +500,6 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -339,15 +507,11 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '528'
|
||||
- '653'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"What is the capital of France?"}],"model":"gpt-5"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '89'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.2
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DO4LcSpy72yIXCYSIVOQEXWNXydgn\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1774628956,\n \"model\": \"gpt-5-2025-08-07\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Paris.\",\n \"refusal\": null,\n
|
||||
\ \"annotations\": []\n },\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 13,\n \"completion_tokens\":
|
||||
11,\n \"total_tokens\": 24,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": null\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9e2fc5dce85582fb-GIG
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 27 Mar 2026 16:29:17 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
content-length:
|
||||
- '772'
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '1343'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -136,6 +136,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
@@ -173,6 +174,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
@@ -201,6 +203,48 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_with_tools_metadata(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
available_exports = [{"name": "MyTool"}]
|
||||
tools_metadata = [
|
||||
{
|
||||
"name": "MyTool",
|
||||
"humanized_name": "my_tool",
|
||||
"description": "A test tool",
|
||||
"run_params_schema": {"type": "object", "properties": {}},
|
||||
"init_params_schema": {"type": "object", "properties": {}},
|
||||
"env_vars": [{"name": "API_KEY", "description": "API key", "required": True, "default": None}],
|
||||
}
|
||||
]
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file,
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata},
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
|
||||
@@ -363,3 +363,290 @@ def test_get_crews_ignores_template_directories(
|
||||
utils.get_crews()
|
||||
|
||||
assert not template_crew_detected
|
||||
|
||||
|
||||
# Tests for extract_tools_metadata
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_project(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty project."""
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when no __init__.py exists."""
|
||||
(temp_project_dir / "some_file.py").write_text("print('hello')")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty __init__.py."""
|
||||
create_init_file(temp_project_dir, "")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_all_variable(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when __all__ is not defined."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"from crewai.tools import BaseTool\n\nclass MyTool(BaseTool):\n pass",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_valid_base_tool_class(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from a valid BaseTool class."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
assert metadata[0]["humanized_name"] == "my_tool"
|
||||
assert metadata[0]["description"] == "A test tool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_args_schema(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts run_params_schema from args_schema."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
class MyToolInput(BaseModel):
|
||||
query: str
|
||||
limit: int = 10
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
args_schema: type[BaseModel] = MyToolInput
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
run_params = metadata[0]["run_params_schema"]
|
||||
assert "properties" in run_params
|
||||
assert "query" in run_params["properties"]
|
||||
assert "limit" in run_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_env_vars(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts env_vars."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(name="MY_API_KEY", description="API key for service", required=True),
|
||||
EnvVar(name="MY_OPTIONAL_VAR", description="Optional var", required=False, default="default_value"),
|
||||
]
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
env_vars = metadata[0]["env_vars"]
|
||||
assert len(env_vars) == 2
|
||||
assert env_vars[0]["name"] == "MY_API_KEY"
|
||||
assert env_vars[0]["description"] == "API key for service"
|
||||
assert env_vars[0]["required"] is True
|
||||
assert env_vars[1]["name"] == "MY_OPTIONAL_VAR"
|
||||
assert env_vars[1]["required"] is False
|
||||
assert env_vars[1]["default"] == "default_value"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_env_vars_field_default_factory(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts env_vars declared with Field(default_factory=...)."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
from pydantic import Field
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(name="MY_TOOL_API", description="API token for my tool", required=True),
|
||||
]
|
||||
)
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
env_vars = metadata[0]["env_vars"]
|
||||
assert len(env_vars) == 1
|
||||
assert env_vars[0]["name"] == "MY_TOOL_API"
|
||||
assert env_vars[0]["description"] == "API token for my tool"
|
||||
assert env_vars[0]["required"] is True
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_custom_init_params(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts init_params_schema with custom params."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
api_endpoint: str = "https://api.example.com"
|
||||
timeout: int = 30
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
init_params = metadata[0]["init_params_schema"]
|
||||
assert "properties" in init_params
|
||||
# Custom params should be included
|
||||
assert "api_endpoint" in init_params["properties"]
|
||||
assert "timeout" in init_params["properties"]
|
||||
# Base params should be filtered out
|
||||
assert "name" not in init_params["properties"]
|
||||
assert "description" not in init_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_tools(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple tools."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class FirstTool(BaseTool):
|
||||
name: str = "first_tool"
|
||||
description: str = "First test tool"
|
||||
|
||||
class SecondTool(BaseTool):
|
||||
name: str = "second_tool"
|
||||
description: str = "Second test tool"
|
||||
|
||||
__all__ = ['FirstTool', 'SecondTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "FirstTool" in names
|
||||
assert "SecondTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_init_files(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple __init__.py files."""
|
||||
# Create tool in root __init__.py
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class RootTool(BaseTool):
|
||||
name: str = "root_tool"
|
||||
description: str = "Root tool"
|
||||
|
||||
__all__ = ['RootTool']
|
||||
""",
|
||||
)
|
||||
|
||||
# Create nested package with another tool
|
||||
nested_dir = temp_project_dir / "nested"
|
||||
nested_dir.mkdir()
|
||||
create_init_file(
|
||||
nested_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class NestedTool(BaseTool):
|
||||
name: str = "nested_tool"
|
||||
description: str = "Nested tool"
|
||||
|
||||
__all__ = ['NestedTool']
|
||||
""",
|
||||
)
|
||||
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "RootTool" in names
|
||||
assert "NestedTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_ignores_non_tool_exports(temp_project_dir):
|
||||
"""Test that extract_tools_metadata ignores non-BaseTool exports."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
def not_a_tool():
|
||||
pass
|
||||
|
||||
SOME_CONSTANT = "value"
|
||||
|
||||
__all__ = ['MyTool', 'not_a_tool', 'SOME_CONSTANT']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_import_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on import error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from nonexistent_module import something
|
||||
|
||||
class MyTool(BaseTool):
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_syntax_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on syntax error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
# Missing closing parenthesis
|
||||
def __init__(self, name:
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
@@ -185,9 +185,14 @@ def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
@patch("crewai.cli.tools.main.ToolCommand._print_current_organization")
|
||||
def test_publish_when_not_in_sync_and_force(
|
||||
mock_print_org,
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -222,6 +227,7 @@ def test_publish_when_not_in_sync_and_force(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
mock_print_org.assert_called_once()
|
||||
|
||||
@@ -242,7 +248,12 @@ def test_publish_when_not_in_sync_and_force(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_success(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -277,6 +288,7 @@ def test_publish_success(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
|
||||
|
||||
@@ -295,7 +307,12 @@ def test_publish_success(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_failure(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -336,7 +353,12 @@ def test_publish_failure(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_api_error(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -362,6 +384,63 @@ def test_publish_api_error(
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
side_effect=Exception("Failed to extract metadata"),
|
||||
)
|
||||
def test_publish_metadata_extraction_failure_continues_with_warning(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
"""Test that metadata extraction failure shows warning but continues publishing."""
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Warning: Could not extract tool metadata" in output
|
||||
assert "Publishing will continue without detailed metadata" in output
|
||||
assert "No tool metadata extracted" in output
|
||||
mock_publish.assert_called_once_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[],
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.Settings")
|
||||
def test_print_current_organization_with_org(mock_settings, capsys, tool_command):
|
||||
mock_settings_instance = MagicMock()
|
||||
|
||||
@@ -132,12 +132,12 @@ def test_embedding_configuration_flow(
|
||||
|
||||
embedder_config = {
|
||||
"provider": "sentence-transformer",
|
||||
"model_name": "all-MiniLM-L6-v2",
|
||||
"config": {"model_name": "all-MiniLM-L6-v2"},
|
||||
}
|
||||
|
||||
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
|
||||
mock_get_embedding.assert_called_once_with(embedder_config)
|
||||
mock_get_embedding.assert_called_once_with(storage.embedder)
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
|
||||
@@ -125,8 +125,8 @@ def test_anthropic_specific_parameters():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.stop_sequences == ["Human:", "Assistant:"]
|
||||
assert llm.stream == True
|
||||
assert llm.client.max_retries == 5
|
||||
assert llm.client.timeout == 60
|
||||
assert llm._client.max_retries == 5
|
||||
assert llm._client.timeout == 60
|
||||
|
||||
|
||||
def test_anthropic_completion_call():
|
||||
@@ -563,8 +563,8 @@ def test_anthropic_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'messages')
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'messages')
|
||||
|
||||
|
||||
def test_anthropic_token_usage_tracking():
|
||||
@@ -574,7 +574,7 @@ def test_anthropic_token_usage_tracking():
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the Anthropic response with usage information
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
with patch.object(llm._client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
|
||||
@@ -639,14 +639,14 @@ def test_anthropic_thinking():
|
||||
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
original_create = llm.client.messages.create
|
||||
original_create = llm._client.messages.create
|
||||
captured_params = {}
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_params.update(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
result = llm.call("What is the weather in Tokyo?")
|
||||
|
||||
assert result is not None
|
||||
@@ -677,14 +677,14 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
# Capture all messages.create calls to verify thinking blocks are included
|
||||
original_create = llm.client.messages.create
|
||||
original_create = llm._client.messages.create
|
||||
captured_calls = []
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_calls.append(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
# First call - establishes context and generates thinking blocks
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
first_result = llm.call(messages)
|
||||
@@ -695,8 +695,8 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert len(first_result) > 0
|
||||
|
||||
# Verify thinking blocks were stored after first response
|
||||
assert len(llm.previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm.previous_thinking_blocks[0]
|
||||
assert len(llm._previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm._previous_thinking_blocks[0]
|
||||
assert first_thinking["type"] == "thinking"
|
||||
assert "thinking" in first_thinking
|
||||
assert "signature" in first_thinking
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_azure_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Azure client responses
|
||||
with patch.object(completion.client, 'complete') as mock_complete:
|
||||
with patch.object(completion._client, 'complete') as mock_complete:
|
||||
# Mock tool call in response with proper type
|
||||
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
|
||||
mock_tool_call.function.name = "get_weather"
|
||||
@@ -698,7 +698,7 @@ def test_azure_environment_variable_endpoint():
|
||||
}):
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
assert llm.client is not None
|
||||
assert llm._client is not None
|
||||
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||
|
||||
|
||||
@@ -709,7 +709,7 @@ def test_azure_token_usage_tracking():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock the Azure response with usage information
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "test response"
|
||||
mock_message.tool_calls = None
|
||||
@@ -747,7 +747,7 @@ def test_azure_http_error_handling():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock an HTTP error
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
|
||||
|
||||
with pytest.raises(HttpResponseError):
|
||||
@@ -966,7 +966,7 @@ def test_azure_improved_error_messages():
|
||||
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
error_401 = HttpResponseError(message="Unauthorized")
|
||||
error_401.status_code = 401
|
||||
mock_complete.side_effect = error_401
|
||||
@@ -1327,7 +1327,7 @@ def test_azure_stop_words_not_applied_to_structured_output():
|
||||
# Without the fix, this would be truncated at "Observation:" breaking the JSON
|
||||
json_response = '{"finding": "The data shows growth", "observation": "Observation: This confirms the hypothesis"}'
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = json_response
|
||||
mock_message.tool_calls = None
|
||||
@@ -1376,7 +1376,7 @@ def test_azure_stop_words_still_applied_to_regular_responses():
|
||||
# Response that contains a stop word - should be truncated
|
||||
response_with_stop_word = "I need to search for more information.\n\nAction: search\nObservation: Found results"
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = response_with_stop_word
|
||||
mock_message.tool_calls = None
|
||||
|
||||
@@ -674,7 +674,7 @@ def test_bedrock_token_usage_tracking():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Mock the Bedrock response with usage information
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
@@ -719,7 +719,7 @@ def test_bedrock_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Bedrock client responses
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
# First response: tool use request
|
||||
tool_use_response = {
|
||||
'output': {
|
||||
@@ -805,7 +805,7 @@ def test_bedrock_client_error_handling():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Test ValidationException
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ValidationException',
|
||||
@@ -819,7 +819,7 @@ def test_bedrock_client_error_handling():
|
||||
assert "validation" in str(exc_info.value).lower()
|
||||
|
||||
# Test ThrottlingException
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ThrottlingException',
|
||||
@@ -861,7 +861,7 @@ def test_bedrock_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
|
||||
@@ -556,8 +556,8 @@ def test_gemini_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'models')
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'models')
|
||||
assert llm.api_key == "test-google-key"
|
||||
|
||||
|
||||
@@ -655,7 +655,7 @@ def test_gemini_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm.client.models, 'generate_content') as mock_generate:
|
||||
with patch.object(llm._client.models, 'generate_content') as mock_generate:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Hello"
|
||||
mock_response.candidates = []
|
||||
|
||||
@@ -371,11 +371,11 @@ def test_openai_client_setup_with_extra_arguments():
|
||||
assert llm.top_p == 0.5
|
||||
|
||||
# Check that client parameters are properly configured
|
||||
assert llm.client.max_retries == 3
|
||||
assert llm.client.timeout == 30
|
||||
assert llm._client.max_retries == 3
|
||||
assert llm._client.timeout == 30
|
||||
|
||||
# Test that parameters are properly used in API calls
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -396,7 +396,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
|
||||
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -507,7 +507,7 @@ def test_openai_streaming_with_response_model():
|
||||
|
||||
llm = LLM(model="openai/gpt-4o", stream=True)
|
||||
|
||||
with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
|
||||
with patch.object(llm._client.beta.chat.completions, "stream") as mock_stream:
|
||||
# Create mock chunks with content.delta event structure
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.type = "content.delta"
|
||||
@@ -1523,6 +1523,69 @@ def test_openai_stop_words_not_applied_to_structured_output():
|
||||
assert "Observation:" in result.observation
|
||||
|
||||
|
||||
def test_openai_gpt5_models_do_not_support_stop_words():
|
||||
"""
|
||||
Test that GPT-5 family models do not support stop words via the API.
|
||||
GPT-5 models reject the 'stop' parameter, so stop words must be
|
||||
applied client-side only.
|
||||
"""
|
||||
gpt5_models = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
"gpt-5-pro",
|
||||
"gpt-5.1",
|
||||
"gpt-5.1-chat",
|
||||
"gpt-5.2",
|
||||
"gpt-5.2-chat",
|
||||
]
|
||||
|
||||
for model_name in gpt5_models:
|
||||
llm = OpenAICompletion(model=model_name)
|
||||
assert llm.supports_stop_words() == False, (
|
||||
f"Expected {model_name} to NOT support stop words"
|
||||
)
|
||||
|
||||
|
||||
def test_openai_non_gpt5_models_support_stop_words():
|
||||
"""
|
||||
Test that non-GPT-5 models still support stop words normally.
|
||||
"""
|
||||
supported_models = [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4-turbo",
|
||||
]
|
||||
|
||||
for model_name in supported_models:
|
||||
llm = OpenAICompletion(model=model_name)
|
||||
assert llm.supports_stop_words() == True, (
|
||||
f"Expected {model_name} to support stop words"
|
||||
)
|
||||
|
||||
|
||||
def test_openai_gpt5_still_applies_stop_words_client_side():
|
||||
"""
|
||||
Test that GPT-5 models still truncate responses at stop words client-side
|
||||
via _apply_stop_words(), even though they don't send 'stop' to the API.
|
||||
"""
|
||||
llm = OpenAICompletion(
|
||||
model="gpt-5.2",
|
||||
stop=["Observation:", "Final Answer:"],
|
||||
)
|
||||
|
||||
assert llm.supports_stop_words() == False
|
||||
|
||||
response = "I need to search.\n\nAction: search\nObservation: Found results"
|
||||
result = llm._apply_stop_words(response)
|
||||
|
||||
assert "Observation:" not in result
|
||||
assert "Found results" not in result
|
||||
assert "I need to search" in result
|
||||
|
||||
|
||||
def test_openai_stop_words_still_applied_to_regular_responses():
|
||||
"""
|
||||
Test that stop words ARE still applied for regular (non-structured) responses.
|
||||
@@ -1767,7 +1830,7 @@ def test_openai_responses_api_cached_prompt_tokens_with_tools():
|
||||
}
|
||||
]
|
||||
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='response')
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='responses')
|
||||
|
||||
# First call with tool
|
||||
llm.call(
|
||||
@@ -1843,7 +1906,7 @@ def test_openai_streaming_returns_tool_calls_without_available_functions():
|
||||
mock_chunk_3.id = "chatcmpl-1"
|
||||
|
||||
with patch.object(
|
||||
llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
llm._client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
):
|
||||
result = llm.call(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
@@ -1934,7 +1997,7 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_funct
|
||||
return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
|
||||
with patch.object(
|
||||
llm.async_client.chat.completions, "create", side_effect=mock_create
|
||||
llm._async_client.chat.completions, "create", side_effect=mock_create
|
||||
):
|
||||
result = await llm.acall(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
@@ -59,7 +61,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
|
||||
"Unsupported provider: invalid_provider"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
|
||||
with pytest.raises(ValidationError):
|
||||
KnowledgeStorage(
|
||||
embedder={"provider": "invalid_provider"},
|
||||
collection_name="invalid_embedding_test",
|
||||
|
||||
@@ -682,6 +682,126 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_litellm_gpt5_call_succeeds_without_stop_error():
|
||||
"""
|
||||
Integration test: GPT-5 call succeeds when stop words are configured,
|
||||
because stop is omitted from API params and applied client-side.
|
||||
"""
|
||||
llm = LLM(model="gpt-5", stop=["Observation:"], is_litellm=True)
|
||||
result = llm.call("What is the capital of France?")
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_litellm_gpt5_does_not_send_stop_in_params():
|
||||
"""
|
||||
Test that the LiteLLM fallback path does not include 'stop' in API params
|
||||
for GPT-5.x models, since they reject it at the API level.
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
params = llm._prepare_completion_params(
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
assert params.get("stop") is None, (
|
||||
"GPT-5.x models should not have 'stop' in API params"
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_non_gpt5_sends_stop_in_params():
|
||||
"""
|
||||
Test that the LiteLLM fallback path still includes 'stop' in API params
|
||||
for models that support it.
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
params = llm._prepare_completion_params(
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
assert params.get("stop") == ["Observation:"], (
|
||||
"Non-GPT-5 models should have 'stop' in API params"
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_retry_catches_litellm_unsupported_params_error(caplog):
|
||||
"""
|
||||
Test that the retry logic catches LiteLLM's UnsupportedParamsError format
|
||||
("does not support parameters") in addition to the OpenAI API format.
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
litellm_error = Exception(
|
||||
"litellm.UnsupportedParamsError: openai does not support parameters: "
|
||||
"['stop'], for model=openai/gpt-5.2."
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
pytest.skip("litellm is not installed; skipping LiteLLM retry test")
|
||||
|
||||
def mock_completion(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise litellm_error
|
||||
return MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
|
||||
usage=MagicMock(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
|
||||
with patch("litellm.completion", side_effect=mock_completion):
|
||||
with caplog.at_level(logging.INFO):
|
||||
result = llm.call("What is the capital of France?")
|
||||
|
||||
assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
|
||||
assert "stop" in llm.additional_params.get("additional_drop_params", [])
|
||||
|
||||
|
||||
def test_litellm_retry_catches_openai_api_stop_error(caplog):
|
||||
"""
|
||||
Test that the retry logic still catches the OpenAI API error format
|
||||
("Unsupported parameter: 'stop'").
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
api_error = Exception(
|
||||
"Unsupported parameter: 'stop' is not supported with this model."
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def mock_completion(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise api_error
|
||||
return MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
|
||||
usage=MagicMock(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
|
||||
with patch("litellm.completion", side_effect=mock_completion):
|
||||
with caplog.at_level(logging.INFO):
|
||||
llm.call("What is the capital of France?")
|
||||
|
||||
assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
|
||||
assert "stop" in llm.additional_params.get("additional_drop_params", [])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_llm():
|
||||
return LLM(model="ollama/llama3.2:3b", is_litellm=True)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, ClassVar
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent import Agent
|
||||
@@ -372,8 +372,11 @@ def test_internal_crew_with_mcp():
|
||||
mock_adapter = Mock()
|
||||
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.__class__ = BaseLLM
|
||||
class _StubLLM(BaseLLM):
|
||||
def call(self, *a: Any, **kw: Any) -> str:
|
||||
return ""
|
||||
|
||||
mock_llm = create_autospec(_StubLLM(model="stub"), instance=True)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
|
||||
|
||||
Reference in New Issue
Block a user