mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-03-31 16:18:14 +00:00
Compare commits
8 Commits
docs/file-
...
gl/refacto
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1dd059a395 | ||
|
|
297f7a0426 | ||
|
|
dfc0f9a317 | ||
|
|
ef79456968 | ||
|
|
6c7ea422e7 | ||
|
|
bb9bcd6823 | ||
|
|
ac14b9127e | ||
|
|
98b7626784 |
@@ -134,10 +134,6 @@ result = flow.kickoff(
|
||||
)
|
||||
```
|
||||
|
||||
<Note type="info" title="CrewAI Platform Integration">
|
||||
When deployed on CrewAI Platform, `ImageFile`, `PDFFile`, and other file-typed fields in your flow state automatically get a file upload UI. Users can drag and drop files directly in the Platform interface. Files are stored securely and passed to agents using provider-specific optimizations (inline base64, file upload APIs, or URL references depending on the provider).
|
||||
</Note>
|
||||
|
||||
### With Standalone Agents
|
||||
|
||||
Pass files directly to agent kickoff:
|
||||
|
||||
@@ -341,87 +341,6 @@ flow.kickoff()
|
||||
|
||||
By providing both unstructured and structured state management options, CrewAI Flows empowers developers to build AI workflows that are both flexible and robust, catering to a wide range of application requirements.
|
||||
|
||||
### File Inputs
|
||||
|
||||
When using structured state, you can include file-typed fields using classes from `crewai-files`. This enables file uploads as part of your flow's input:
|
||||
|
||||
```python
|
||||
from crewai.flow.flow import Flow, start
|
||||
from crewai_files import ImageFile, PDFFile
|
||||
from pydantic import BaseModel
|
||||
|
||||
class OnboardingState(BaseModel):
|
||||
document: PDFFile # File upload
|
||||
cover_image: ImageFile # Image upload
|
||||
title: str = "" # Text input
|
||||
|
||||
class OnboardingFlow(Flow[OnboardingState]):
|
||||
@start()
|
||||
def process_upload(self):
|
||||
# Access files directly from state
|
||||
print(f"Processing: {self.state.title}")
|
||||
return self.state.document
|
||||
```
|
||||
|
||||
When deployed on **CrewAI Platform**, file-typed fields automatically render as file upload dropzones in the UI. Users can drag and drop files, which are then passed to your flow.
|
||||
|
||||
**Kicking off with files via API:**
|
||||
|
||||
The `/kickoff` endpoint auto-detects the request format:
|
||||
- **JSON body** → normal kickoff
|
||||
- **multipart/form-data** → file upload + kickoff
|
||||
|
||||
API users can also pass URL strings directly to file-typed fields—Pydantic coerces them automatically.
|
||||
|
||||
### API Usage
|
||||
|
||||
#### Option 1: Multipart kickoff (recommended)
|
||||
|
||||
Send files directly with the kickoff request:
|
||||
|
||||
```bash
|
||||
# With files (multipart) — same endpoint
|
||||
curl -X POST https://your-deployment.crewai.com/kickoff \
|
||||
-H 'Authorization: Bearer YOUR_TOKEN' \
|
||||
-F 'inputs={"company_name": "Einstein"}' \
|
||||
-F 'cnh_image=@/path/to/document.jpg'
|
||||
```
|
||||
|
||||
Files are automatically stored and converted to `FileInput` objects. The agent receives the file with provider-specific optimization (inline base64, file upload API, or URL reference depending on the LLM provider).
|
||||
|
||||
#### Option 2: JSON kickoff (no files)
|
||||
|
||||
```bash
|
||||
# Without files (JSON) — same endpoint
|
||||
curl -X POST https://your-deployment.crewai.com/kickoff \
|
||||
-H 'Authorization: Bearer YOUR_TOKEN' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"inputs": {"company_name": "Einstein"}}'
|
||||
```
|
||||
|
||||
#### Option 3: Separate upload + kickoff
|
||||
|
||||
Upload files first, then reference them:
|
||||
|
||||
```bash
|
||||
# Step 1: Upload
|
||||
curl -X POST https://your-deployment.crewai.com/files \
|
||||
-H 'Authorization: Bearer YOUR_TOKEN' \
|
||||
-F 'file=@/path/to/document.jpg' \
|
||||
-F 'field_name=cnh_image'
|
||||
# Returns: {"url": "https://...", "field_name": "cnh_image"}
|
||||
|
||||
# Step 2: Kickoff with URL
|
||||
curl -X POST https://your-deployment.crewai.com/kickoff \
|
||||
-H 'Authorization: Bearer YOUR_TOKEN' \
|
||||
-H 'Content-Type: application/json' \
|
||||
-d '{"inputs": {"company_name": "Einstein"}, "inputFiles": {"cnh_image": "https://..."}}'
|
||||
```
|
||||
|
||||
#### On CrewAI Platform
|
||||
|
||||
When using the Platform UI, file-typed fields automatically render as drag-and-drop upload zones. No API calls needed—just drop the file and click Run.
|
||||
|
||||
## Flow Persistence
|
||||
|
||||
The @persist decorator enables automatic state persistence in CrewAI Flows, allowing you to maintain flow state across restarts or different workflow executions. This decorator can be applied at either the class level or method level, providing flexibility in how you manage state persistence.
|
||||
|
||||
@@ -6,6 +6,7 @@ import warnings
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.agent.planning_config import PlanningConfig
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.crew import Crew
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.flow.flow import Flow
|
||||
@@ -19,6 +20,9 @@ from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.telemetry.telemetry import Telemetry
|
||||
|
||||
|
||||
CrewAgentExecutor.model_rebuild()
|
||||
|
||||
|
||||
def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
"""Suppress Pydantic deprecation warnings using targeted monkey patch."""
|
||||
original_warn = warnings.warn
|
||||
|
||||
@@ -25,7 +25,6 @@ from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
@@ -167,10 +166,10 @@ class Agent(BaseAgent):
|
||||
default=True,
|
||||
description="Use system prompt for the agent.",
|
||||
)
|
||||
llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
function_calling_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
system_template: str | None = Field(
|
||||
|
||||
@@ -12,7 +12,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -185,7 +184,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
default=None,
|
||||
description="Knowledge sources for the agent.",
|
||||
)
|
||||
knowledge_storage: InstanceOf[BaseKnowledgeStorage] | None = Field(
|
||||
knowledge_storage: BaseKnowledgeStorage | None = Field(
|
||||
default=None,
|
||||
description="Custom knowledge storage for the agent.",
|
||||
)
|
||||
|
||||
@@ -14,8 +14,15 @@ import inspect
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel, GetCoreSchemaHandler, ValidationError
|
||||
from pydantic_core import CoreSchema, core_schema
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
ValidationError,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin
|
||||
from crewai.agents.parser import (
|
||||
@@ -23,6 +30,7 @@ from crewai.agents.parser import (
|
||||
AgentFinish,
|
||||
OutputParserError,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.core.providers.human_input import ExecutorContext, get_provider
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.logging_events import (
|
||||
@@ -38,6 +46,9 @@ from crewai.hooks.tool_hooks import (
|
||||
get_after_tool_call_hooks,
|
||||
get_before_tool_call_hooks,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
convert_tools_to_openai_schema,
|
||||
@@ -59,106 +70,65 @@ from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.file_store import aget_all_files, get_all_files
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
execute_tool_and_check_finality,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent import Agent
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.crew import Crew
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.tools.tool_types import ToolResult
|
||||
from crewai.utilities.prompts import StandardPromptResult, SystemPromptResult
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
class CrewAgentExecutor(BaseModel, CrewAgentExecutorMixin):
|
||||
"""Executor for crew agents.
|
||||
|
||||
Manages the execution lifecycle of an agent including prompt formatting,
|
||||
LLM interactions, tool execution, and feedback handling.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
llm: BaseLLM,
|
||||
task: Task,
|
||||
crew: Crew,
|
||||
agent: Agent,
|
||||
prompt: SystemPromptResult | StandardPromptResult,
|
||||
max_iter: int,
|
||||
tools: list[CrewStructuredTool],
|
||||
tools_names: str,
|
||||
stop_words: list[str],
|
||||
tools_description: str,
|
||||
tools_handler: ToolsHandler,
|
||||
step_callback: Any = None,
|
||||
original_tools: list[BaseTool] | None = None,
|
||||
function_calling_llm: BaseLLM | Any | None = None,
|
||||
respect_context_window: bool = False,
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
i18n: I18N | None = None,
|
||||
) -> None:
|
||||
"""Initialize executor.
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
Args:
|
||||
llm: Language model instance.
|
||||
task: Task to execute.
|
||||
crew: Crew instance.
|
||||
agent: Agent to execute.
|
||||
prompt: Prompt templates.
|
||||
max_iter: Maximum iterations.
|
||||
tools: Available tools.
|
||||
tools_names: Tool names string.
|
||||
stop_words: Stop word list.
|
||||
tools_description: Tool descriptions.
|
||||
tools_handler: Tool handler instance.
|
||||
step_callback: Optional step callback.
|
||||
original_tools: Original tool list.
|
||||
function_calling_llm: Optional function calling LLM.
|
||||
respect_context_window: Respect context limits.
|
||||
request_within_rpm_limit: RPM limit check function.
|
||||
callbacks: Optional callbacks list.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
"""
|
||||
self._i18n: I18N = i18n or get_i18n()
|
||||
self.llm = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew = crew
|
||||
self.prompt = prompt
|
||||
self.tools = tools
|
||||
self.tools_names = tools_names
|
||||
self.stop = stop_words
|
||||
self.max_iter = max_iter
|
||||
self.callbacks = callbacks or []
|
||||
self._printer: Printer = Printer()
|
||||
self.tools_handler = tools_handler
|
||||
self.original_tools = original_tools or []
|
||||
self.step_callback = step_callback
|
||||
self.tools_description = tools_description
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.respect_context_window = respect_context_window
|
||||
self.request_within_rpm_limit = request_within_rpm_limit
|
||||
self.response_model = response_model
|
||||
self.ask_for_human_input = False
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
llm: BaseLLM
|
||||
task: Task | None = None
|
||||
crew: Crew | None = None
|
||||
agent: Agent
|
||||
prompt: SystemPromptResult | StandardPromptResult
|
||||
max_iter: int
|
||||
tools: list[CrewStructuredTool]
|
||||
tools_names: str
|
||||
stop: list[str] = Field(alias="stop_words")
|
||||
tools_description: str
|
||||
tools_handler: ToolsHandler
|
||||
step_callback: Any = None
|
||||
original_tools: list[BaseTool] = Field(default_factory=list)
|
||||
function_calling_llm: BaseLLM | Any | None = None
|
||||
respect_context_window: bool = False
|
||||
request_within_rpm_limit: Callable[[], bool] | None = None
|
||||
callbacks: list[Any] = Field(default_factory=list)
|
||||
response_model: type[BaseModel] | None = None
|
||||
i18n: I18N | None = Field(default=None, exclude=True)
|
||||
ask_for_human_input: bool = False
|
||||
messages: list[LLMMessage] = Field(default_factory=list)
|
||||
iterations: int = 0
|
||||
log_error_after: int = 3
|
||||
before_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
after_llm_call_hooks: list[Callable[..., Any]] = Field(default_factory=list)
|
||||
_i18n: I18N = PrivateAttr()
|
||||
_printer: Printer = PrivateAttr(default_factory=Printer)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_executor(self) -> Self:
|
||||
self._i18n = self.i18n or get_i18n()
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
@@ -171,6 +141,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
else self.stop
|
||||
)
|
||||
)
|
||||
return self
|
||||
|
||||
@property
|
||||
def use_stop_words(self) -> bool:
|
||||
@@ -1687,14 +1658,3 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
return format_message_for_llm(
|
||||
self._i18n.slice("feedback_instructions").format(feedback=feedback)
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def __get_pydantic_core_schema__(
|
||||
cls, _source_type: Any, _handler: GetCoreSchemaHandler
|
||||
) -> CoreSchema:
|
||||
"""Generate Pydantic core schema for BaseClient Protocol.
|
||||
|
||||
This allows the Protocol to be used in Pydantic models without
|
||||
requiring arbitrary_types_allowed=True.
|
||||
"""
|
||||
return core_schema.any_schema()
|
||||
|
||||
@@ -73,6 +73,7 @@ class PlusAPI:
|
||||
description: str | None,
|
||||
encoded_file: str,
|
||||
available_exports: list[dict[str, Any]] | None = None,
|
||||
tools_metadata: list[dict[str, Any]] | None = None,
|
||||
) -> httpx.Response:
|
||||
params = {
|
||||
"handle": handle,
|
||||
@@ -81,6 +82,9 @@ class PlusAPI:
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata}
|
||||
if tools_metadata is not None
|
||||
else None,
|
||||
}
|
||||
return self._make_request("POST", f"{self.TOOLS_RESOURCE}", json=params)
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ from crewai.cli.constants import DEFAULT_CREWAI_ENTERPRISE_URL
|
||||
from crewai.cli.utils import (
|
||||
build_env_with_tool_repository_credentials,
|
||||
extract_available_exports,
|
||||
extract_tools_metadata,
|
||||
get_project_description,
|
||||
get_project_name,
|
||||
get_project_version,
|
||||
@@ -101,6 +102,18 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
console.print(
|
||||
f"[green]Found these tools to publish: {', '.join([e['name'] for e in available_exports])}[/green]"
|
||||
)
|
||||
|
||||
console.print("[bold blue]Extracting tool metadata...[/bold blue]")
|
||||
try:
|
||||
tools_metadata = extract_tools_metadata()
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract tool metadata: {e}[/yellow]\n"
|
||||
f"Publishing will continue without detailed metadata."
|
||||
)
|
||||
tools_metadata = []
|
||||
|
||||
self._print_tools_preview(tools_metadata)
|
||||
self._print_current_organization()
|
||||
|
||||
with tempfile.TemporaryDirectory() as temp_build_dir:
|
||||
@@ -118,7 +131,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
"Project build failed. Please ensure that the command `uv build --sdist` completes successfully.",
|
||||
style="bold red",
|
||||
)
|
||||
raise SystemExit
|
||||
raise SystemExit(1)
|
||||
|
||||
tarball_path = os.path.join(temp_build_dir, tarball_filename)
|
||||
with open(tarball_path, "rb") as file:
|
||||
@@ -134,6 +147,7 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
description=project_description,
|
||||
encoded_file=f"data:application/x-gzip;base64,{encoded_tarball}",
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
self._validate_response(publish_response)
|
||||
@@ -246,6 +260,55 @@ class ToolCommand(BaseCommand, PlusAPIMixin):
|
||||
)
|
||||
raise SystemExit
|
||||
|
||||
def _print_tools_preview(self, tools_metadata: list[dict[str, Any]]) -> None:
|
||||
if not tools_metadata:
|
||||
console.print("[yellow]No tool metadata extracted.[/yellow]")
|
||||
return
|
||||
|
||||
console.print(
|
||||
f"\n[bold]Tools to be published ({len(tools_metadata)}):[/bold]\n"
|
||||
)
|
||||
|
||||
for tool in tools_metadata:
|
||||
console.print(f" [bold cyan]{tool.get('name', 'Unknown')}[/bold cyan]")
|
||||
if tool.get("module"):
|
||||
console.print(f" Module: {tool.get('module')}")
|
||||
console.print(f" Name: {tool.get('humanized_name', 'N/A')}")
|
||||
console.print(
|
||||
f" Description: {tool.get('description', 'N/A')[:80]}{'...' if len(tool.get('description', '')) > 80 else ''}"
|
||||
)
|
||||
|
||||
init_params = tool.get("init_params_schema", {}).get("properties", {})
|
||||
if init_params:
|
||||
required = tool.get("init_params_schema", {}).get("required", [])
|
||||
console.print(" Init parameters:")
|
||||
for param_name, param_info in init_params.items():
|
||||
param_type = param_info.get("type", "any")
|
||||
is_required = param_name in required
|
||||
req_marker = "[red]*[/red]" if is_required else ""
|
||||
default = (
|
||||
f" = {param_info['default']}" if "default" in param_info else ""
|
||||
)
|
||||
console.print(
|
||||
f" - {param_name}: {param_type}{default} {req_marker}"
|
||||
)
|
||||
|
||||
env_vars = tool.get("env_vars", [])
|
||||
if env_vars:
|
||||
console.print(" Environment variables:")
|
||||
for env_var in env_vars:
|
||||
req_marker = "[red]*[/red]" if env_var.get("required") else ""
|
||||
default = (
|
||||
f" (default: {env_var['default']})"
|
||||
if env_var.get("default")
|
||||
else ""
|
||||
)
|
||||
console.print(
|
||||
f" - {env_var['name']}: {env_var.get('description', 'N/A')}{default} {req_marker}"
|
||||
)
|
||||
|
||||
console.print()
|
||||
|
||||
def _print_current_organization(self) -> None:
|
||||
settings = Settings()
|
||||
if settings.org_uuid:
|
||||
|
||||
@@ -1,10 +1,15 @@
|
||||
from functools import reduce
|
||||
from collections.abc import Generator, Mapping
|
||||
from contextlib import contextmanager
|
||||
from functools import lru_cache, reduce
|
||||
import hashlib
|
||||
import importlib.util
|
||||
import inspect
|
||||
from inspect import getmro, isclass, isfunction, ismethod
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
import sys
|
||||
import types
|
||||
from typing import Any, cast, get_type_hints
|
||||
|
||||
import click
|
||||
@@ -544,43 +549,62 @@ def build_env_with_tool_repository_credentials(
|
||||
return env
|
||||
|
||||
|
||||
@contextmanager
|
||||
def _load_module_from_file(
|
||||
init_file: Path, module_name: str | None = None
|
||||
) -> Generator[types.ModuleType | None, None, None]:
|
||||
"""
|
||||
Context manager for loading a module from file with automatic cleanup.
|
||||
|
||||
Yields the loaded module or None if loading fails.
|
||||
"""
|
||||
if module_name is None:
|
||||
module_name = (
|
||||
f"temp_module_{hashlib.sha256(str(init_file).encode()).hexdigest()[:8]}"
|
||||
)
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, init_file)
|
||||
if not spec or not spec.loader:
|
||||
yield None
|
||||
return
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules[module_name] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
yield module
|
||||
finally:
|
||||
sys.modules.pop(module_name, None)
|
||||
|
||||
|
||||
def _load_tools_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load and validate tools from a given __init__.py file.
|
||||
"""
|
||||
spec = importlib.util.spec_from_file_location("temp_module", init_file)
|
||||
|
||||
if not spec or not spec.loader:
|
||||
return []
|
||||
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
sys.modules["temp_module"] = module
|
||||
|
||||
try:
|
||||
spec.loader.exec_module(module)
|
||||
with _load_module_from_file(init_file) as module:
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
if not hasattr(module, "__all__"):
|
||||
console.print(
|
||||
f"Warning: No __all__ defined in {init_file}",
|
||||
style="bold yellow",
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
return [
|
||||
{
|
||||
"name": name,
|
||||
}
|
||||
for name in module.__all__
|
||||
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
||||
]
|
||||
if not hasattr(module, "__all__"):
|
||||
console.print(
|
||||
f"Warning: No __all__ defined in {init_file}",
|
||||
style="bold yellow",
|
||||
)
|
||||
raise SystemExit(1)
|
||||
|
||||
return [
|
||||
{"name": name}
|
||||
for name in module.__all__
|
||||
if hasattr(module, name) and is_valid_tool(getattr(module, name))
|
||||
]
|
||||
except SystemExit:
|
||||
raise
|
||||
except Exception as e:
|
||||
console.print(f"[red]Warning: Could not load {init_file}: {e!s}[/red]")
|
||||
raise SystemExit(1) from e
|
||||
|
||||
finally:
|
||||
sys.modules.pop("temp_module", None)
|
||||
|
||||
|
||||
def _print_no_tools_warning() -> None:
|
||||
"""
|
||||
@@ -610,3 +634,242 @@ def _print_no_tools_warning() -> None:
|
||||
" # ... implementation\n"
|
||||
" return result\n"
|
||||
)
|
||||
|
||||
|
||||
def extract_tools_metadata(dir_path: str = "src") -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract rich metadata from tool classes in the project.
|
||||
|
||||
Returns a list of tool metadata dictionaries containing:
|
||||
- name: Class name
|
||||
- humanized_name: From name field default
|
||||
- description: From description field default
|
||||
- run_params_schema: JSON Schema for _run() params (from args_schema)
|
||||
- init_params_schema: JSON Schema for __init__ params (filtered)
|
||||
- env_vars: List of environment variable dicts
|
||||
"""
|
||||
tools_metadata: list[dict[str, Any]] = []
|
||||
|
||||
for init_file in Path(dir_path).glob("**/__init__.py"):
|
||||
tools = _extract_tool_metadata_from_init(init_file)
|
||||
tools_metadata.extend(tools)
|
||||
|
||||
return tools_metadata
|
||||
|
||||
|
||||
def _extract_tool_metadata_from_init(init_file: Path) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Load module from init file and extract metadata from valid tool classes.
|
||||
"""
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
try:
|
||||
with _load_module_from_file(init_file) as module:
|
||||
if module is None:
|
||||
return []
|
||||
|
||||
exported_names = getattr(module, "__all__", None)
|
||||
if not exported_names:
|
||||
return []
|
||||
|
||||
tools_metadata = []
|
||||
for name in exported_names:
|
||||
obj = getattr(module, name, None)
|
||||
if obj is None or not (
|
||||
inspect.isclass(obj) and issubclass(obj, BaseTool)
|
||||
):
|
||||
continue
|
||||
if tool_info := _extract_single_tool_metadata(obj):
|
||||
tools_metadata.append(tool_info)
|
||||
|
||||
return tools_metadata
|
||||
except Exception as e:
|
||||
console.print(
|
||||
f"[yellow]Warning: Could not extract metadata from {init_file}: {e}[/yellow]"
|
||||
)
|
||||
return []
|
||||
|
||||
|
||||
def _extract_single_tool_metadata(tool_class: type) -> dict[str, Any] | None:
|
||||
"""
|
||||
Extract metadata from a single tool class.
|
||||
"""
|
||||
try:
|
||||
core_schema = cast(Any, tool_class).__pydantic_core_schema__
|
||||
if not core_schema:
|
||||
return None
|
||||
|
||||
schema = _unwrap_schema(core_schema)
|
||||
fields = schema.get("schema", {}).get("fields", {})
|
||||
|
||||
try:
|
||||
file_path = inspect.getfile(tool_class)
|
||||
relative_path = Path(file_path).relative_to(Path.cwd())
|
||||
module_path = relative_path.with_suffix("")
|
||||
if module_path.parts[0] == "src":
|
||||
module_path = Path(*module_path.parts[1:])
|
||||
if module_path.name == "__init__":
|
||||
module_path = module_path.parent
|
||||
module = ".".join(module_path.parts)
|
||||
except (TypeError, ValueError):
|
||||
module = tool_class.__module__
|
||||
|
||||
return {
|
||||
"name": tool_class.__name__,
|
||||
"module": module,
|
||||
"humanized_name": _extract_field_default(
|
||||
fields.get("name"), fallback=tool_class.__name__
|
||||
),
|
||||
"description": str(
|
||||
_extract_field_default(fields.get("description"))
|
||||
).strip(),
|
||||
"run_params_schema": _extract_run_params_schema(fields.get("args_schema")),
|
||||
"init_params_schema": _extract_init_params_schema(tool_class),
|
||||
"env_vars": _extract_env_vars(fields.get("env_vars")),
|
||||
}
|
||||
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
|
||||
def _unwrap_schema(schema: Mapping[str, Any] | dict[str, Any]) -> dict[str, Any]:
|
||||
"""
|
||||
Unwrap nested schema structures to get to the actual schema definition.
|
||||
"""
|
||||
result: dict[str, Any] = dict(schema)
|
||||
while (
|
||||
result.get("type")
|
||||
in {"function-after", "function-before", "function-wrap", "default"}
|
||||
and "schema" in result
|
||||
):
|
||||
result = dict(result["schema"])
|
||||
if result.get("type") == "definitions" and "schema" in result:
|
||||
result = dict(result["schema"])
|
||||
return result
|
||||
|
||||
|
||||
def _extract_field_default(
|
||||
field: dict[str, Any] | None, fallback: str | list[Any] = ""
|
||||
) -> str | list[Any] | int:
|
||||
"""
|
||||
Extract the default value from a field schema.
|
||||
"""
|
||||
if not field:
|
||||
return fallback
|
||||
|
||||
schema = field.get("schema", {})
|
||||
default = schema.get("default")
|
||||
return default if isinstance(default, (list, str, int)) else fallback
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def _get_schema_generator() -> type:
|
||||
"""Get a SchemaGenerator that omits non-serializable defaults."""
|
||||
from pydantic.json_schema import GenerateJsonSchema
|
||||
from pydantic_core import PydanticOmit
|
||||
|
||||
class SchemaGenerator(GenerateJsonSchema):
|
||||
def handle_invalid_for_json_schema(
|
||||
self, schema: Any, error_info: Any
|
||||
) -> dict[str, Any]:
|
||||
raise PydanticOmit
|
||||
|
||||
return SchemaGenerator
|
||||
|
||||
|
||||
def _extract_run_params_schema(
|
||||
args_schema_field: dict[str, Any] | None,
|
||||
) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's run parameters from args_schema field.
|
||||
"""
|
||||
from pydantic import BaseModel
|
||||
|
||||
if not args_schema_field:
|
||||
return {}
|
||||
|
||||
args_schema_class = args_schema_field.get("schema", {}).get("default")
|
||||
if not (
|
||||
inspect.isclass(args_schema_class) and issubclass(args_schema_class, BaseModel)
|
||||
):
|
||||
return {}
|
||||
|
||||
try:
|
||||
return args_schema_class.model_json_schema(
|
||||
schema_generator=_get_schema_generator()
|
||||
)
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
_IGNORED_INIT_PARAMS = frozenset(
|
||||
{
|
||||
"name",
|
||||
"description",
|
||||
"env_vars",
|
||||
"args_schema",
|
||||
"description_updated",
|
||||
"cache_function",
|
||||
"result_as_answer",
|
||||
"max_usage_count",
|
||||
"current_usage_count",
|
||||
"package_dependencies",
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def _extract_init_params_schema(tool_class: type) -> dict[str, Any]:
|
||||
"""
|
||||
Extract JSON Schema for the tool's __init__ parameters, filtering out base fields.
|
||||
"""
|
||||
try:
|
||||
json_schema: dict[str, Any] = cast(Any, tool_class).model_json_schema(
|
||||
schema_generator=_get_schema_generator(), mode="serialization"
|
||||
)
|
||||
filtered_properties = {
|
||||
key: value
|
||||
for key, value in json_schema.get("properties", {}).items()
|
||||
if key not in _IGNORED_INIT_PARAMS
|
||||
}
|
||||
json_schema["properties"] = filtered_properties
|
||||
if "required" in json_schema:
|
||||
json_schema["required"] = [
|
||||
key for key in json_schema["required"] if key in filtered_properties
|
||||
]
|
||||
return json_schema
|
||||
except Exception:
|
||||
return {}
|
||||
|
||||
|
||||
def _extract_env_vars(env_vars_field: dict[str, Any] | None) -> list[dict[str, Any]]:
|
||||
"""
|
||||
Extract environment variable definitions from env_vars field.
|
||||
"""
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
if not env_vars_field:
|
||||
return []
|
||||
|
||||
schema = env_vars_field.get("schema", {})
|
||||
default = schema.get("default")
|
||||
if default is None:
|
||||
default_factory = schema.get("default_factory")
|
||||
if callable(default_factory):
|
||||
try:
|
||||
default = default_factory()
|
||||
except Exception:
|
||||
default = []
|
||||
|
||||
if not isinstance(default, list):
|
||||
return []
|
||||
|
||||
return [
|
||||
{
|
||||
"name": env_var.name,
|
||||
"description": env_var.description,
|
||||
"required": env_var.required,
|
||||
"default": env_var.default,
|
||||
}
|
||||
for env_var in default
|
||||
if isinstance(env_var, EnvVar)
|
||||
]
|
||||
|
||||
@@ -22,7 +22,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
Json,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
@@ -176,7 +175,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
_rpm_controller: RPMController = PrivateAttr()
|
||||
_logger: Logger = PrivateAttr()
|
||||
_file_handler: FileHandler = PrivateAttr()
|
||||
_cache_handler: InstanceOf[CacheHandler] = PrivateAttr(default_factory=CacheHandler)
|
||||
_cache_handler: CacheHandler = PrivateAttr(default_factory=CacheHandler)
|
||||
_memory: Memory | MemoryScope | MemorySlice | None = PrivateAttr(default=None)
|
||||
_train: bool | None = PrivateAttr(default=False)
|
||||
_train_iteration: int | None = PrivateAttr()
|
||||
@@ -210,13 +209,13 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=None,
|
||||
description="Metrics for the LLM usage during all tasks execution.",
|
||||
)
|
||||
manager_llm: str | InstanceOf[BaseLLM] | None = Field(
|
||||
manager_llm: str | BaseLLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
manager_agent: BaseAgent | None = Field(
|
||||
description="Custom agent that will be used as manager.", default=None
|
||||
)
|
||||
function_calling_llm: str | InstanceOf[LLM] | None = Field(
|
||||
function_calling_llm: str | LLM | None = Field(
|
||||
description="Language model that will run the agent.", default=None
|
||||
)
|
||||
config: Json[dict[str, Any]] | dict[str, Any] | None = Field(default=None)
|
||||
@@ -267,7 +266,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default=False,
|
||||
description="Plan the crew execution and add the plan to the crew.",
|
||||
)
|
||||
planning_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
planning_llm: str | BaseLLM | Any | None = Field(
|
||||
default=None,
|
||||
description=(
|
||||
"Language model that will run the AgentPlanner if planning is True."
|
||||
@@ -288,7 +287,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"knowledge object."
|
||||
),
|
||||
)
|
||||
chat_llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
chat_llm: str | BaseLLM | Any | None = Field(
|
||||
default=None,
|
||||
description="LLM used to handle chatting with the crew.",
|
||||
)
|
||||
@@ -1800,7 +1799,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: str | InstanceOf[BaseLLM],
|
||||
eval_llm: str | BaseLLM,
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations.
|
||||
|
||||
@@ -1966,37 +1966,6 @@ class AgentExecutor(Flow[AgentExecutorState], CrewAgentExecutorMixin):
|
||||
"original_tool": original_tool,
|
||||
}
|
||||
|
||||
def _extract_tool_name(self, tool_call: Any) -> str:
|
||||
"""Extract tool name from various tool call formats."""
|
||||
if hasattr(tool_call, "function"):
|
||||
return sanitize_tool_name(tool_call.function.name)
|
||||
if hasattr(tool_call, "function_call") and tool_call.function_call:
|
||||
return sanitize_tool_name(tool_call.function_call.name)
|
||||
if hasattr(tool_call, "name"):
|
||||
return sanitize_tool_name(tool_call.name)
|
||||
if isinstance(tool_call, dict):
|
||||
func_info = tool_call.get("function", {})
|
||||
return sanitize_tool_name(
|
||||
func_info.get("name", "") or tool_call.get("name", "unknown")
|
||||
)
|
||||
return "unknown"
|
||||
|
||||
@router(execute_native_tool)
|
||||
def check_native_todo_completion(
|
||||
self,
|
||||
) -> Literal["todo_satisfied", "todo_not_satisfied"]:
|
||||
"""Check if the native tool execution satisfied the active todo.
|
||||
|
||||
Similar to check_todo_completion but for native tool execution path.
|
||||
"""
|
||||
current_todo = self.state.todos.current_todo
|
||||
|
||||
if not current_todo:
|
||||
return "todo_not_satisfied"
|
||||
|
||||
# For native tools, any tool execution satisfies the todo
|
||||
return "todo_satisfied"
|
||||
|
||||
@listen("initialized")
|
||||
def continue_iteration(self) -> Literal["check_iteration"]:
|
||||
"""Bridge listener that connects iteration loop back to iteration check."""
|
||||
|
||||
@@ -3,12 +3,15 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.rag.types import SearchResult
|
||||
|
||||
|
||||
class BaseKnowledgeStorage(ABC):
|
||||
class BaseKnowledgeStorage(BaseModel, ABC):
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
"""Abstract base class for knowledge storage implementations."""
|
||||
|
||||
@abstractmethod
|
||||
|
||||
@@ -3,6 +3,9 @@ import traceback
|
||||
from typing import Any, cast
|
||||
import warnings
|
||||
|
||||
from pydantic import Field, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
from crewai.rag.chromadb.config import ChromaDBConfig
|
||||
from crewai.rag.chromadb.types import ChromaEmbeddingFunctionWrapper
|
||||
@@ -22,31 +25,32 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
search efficiency.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
collection_name: str | None = None
|
||||
embedder: (
|
||||
ProviderSpec
|
||||
| BaseEmbeddingsProvider[Any]
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
self.collection_name = collection_name
|
||||
self._client: BaseClient | None = None
|
||||
| None
|
||||
) = Field(default=None, exclude=True)
|
||||
_client: BaseClient | None = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> Self:
|
||||
warnings.filterwarnings(
|
||||
"ignore",
|
||||
message=r".*'model_fields'.*is deprecated.*",
|
||||
module=r"^chromadb(\.|$)",
|
||||
)
|
||||
|
||||
if embedder:
|
||||
embedding_function = build_embedder(embedder) # type: ignore[arg-type]
|
||||
if self.embedder:
|
||||
embedding_function = build_embedder(self.embedder) # type: ignore[arg-type]
|
||||
config = ChromaDBConfig(
|
||||
embedding_function=cast(
|
||||
ChromaEmbeddingFunctionWrapper, embedding_function
|
||||
)
|
||||
)
|
||||
self._client = create_client(config)
|
||||
return self
|
||||
|
||||
def _get_client(self) -> BaseClient:
|
||||
"""Get the appropriate client - instance-specific or global."""
|
||||
|
||||
@@ -22,7 +22,6 @@ from pydantic import (
|
||||
UUID4,
|
||||
BaseModel,
|
||||
Field,
|
||||
InstanceOf,
|
||||
PrivateAttr,
|
||||
field_validator,
|
||||
model_validator,
|
||||
@@ -204,7 +203,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
role: str = Field(description="Role of the agent")
|
||||
goal: str = Field(description="Goal of the agent")
|
||||
backstory: str = Field(description="Backstory of the agent")
|
||||
llm: str | InstanceOf[BaseLLM] | Any | None = Field(
|
||||
llm: str | BaseLLM | Any | None = Field(
|
||||
default=None, description="Language model that will run the agent"
|
||||
)
|
||||
tools: list[BaseTool] = Field(
|
||||
|
||||
@@ -20,8 +20,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from dotenv import load_dotenv
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
@@ -37,7 +36,12 @@ from crewai.events.types.tool_usage_events import (
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageStartedEvent,
|
||||
)
|
||||
from crewai.llms.base_llm import BaseLLM, get_current_call_id, llm_call_context
|
||||
from crewai.llms.base_llm import (
|
||||
BaseLLM,
|
||||
JsonResponseFormat,
|
||||
get_current_call_id,
|
||||
llm_call_context,
|
||||
)
|
||||
from crewai.llms.constants import (
|
||||
ANTHROPIC_MODELS,
|
||||
AZURE_MODELS,
|
||||
@@ -63,8 +67,6 @@ except ImportError:
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -342,6 +344,27 @@ class AccumulatedToolArgs(BaseModel):
|
||||
|
||||
class LLM(BaseLLM):
|
||||
completion_cost: float | None = None
|
||||
timeout: float | int | None = None
|
||||
top_p: float | None = None
|
||||
n: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
max_tokens: int | float | None = None
|
||||
presence_penalty: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
logit_bias: dict[int, float] | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
seed: int | None = None
|
||||
logprobs: int | None = None
|
||||
top_logprobs: int | None = None
|
||||
api_base: str | None = None
|
||||
api_version: str | None = None
|
||||
callbacks: list[Any] | None = None
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None
|
||||
stream: bool = False
|
||||
interceptor: Any = None
|
||||
thinking: Any = None
|
||||
context_window_size: int = 0
|
||||
is_anthropic: bool = False
|
||||
|
||||
def __new__(cls, model: str, is_litellm: bool = False, **kwargs: Any) -> LLM:
|
||||
"""Factory method that routes to native SDK or falls back to LiteLLM.
|
||||
@@ -436,10 +459,7 @@ class LLM(BaseLLM):
|
||||
logger.error(error_msg)
|
||||
raise ImportError(error_msg) from None
|
||||
|
||||
instance = object.__new__(cls)
|
||||
super(LLM, instance).__init__(model=model, is_litellm=True, **kwargs)
|
||||
instance.is_litellm = True
|
||||
return instance
|
||||
return object.__new__(cls)
|
||||
|
||||
@classmethod
|
||||
def _matches_provider_pattern(cls, model: str, provider: str) -> bool:
|
||||
@@ -624,89 +644,23 @@ class LLM(BaseLLM):
|
||||
|
||||
return None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
timeout: float | int | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
n: int | None = None,
|
||||
stop: str | list[str] | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
max_tokens: int | float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
logit_bias: dict[int, float] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
seed: int | None = None,
|
||||
logprobs: int | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
base_url: str | None = None,
|
||||
api_base: str | None = None,
|
||||
api_version: str | None = None,
|
||||
api_key: str | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize LLM instance.
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_llm_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
model = data.get("model", "")
|
||||
data["is_anthropic"] = cls._is_anthropic_model(model)
|
||||
return data
|
||||
|
||||
Note: This __init__ method is only called for fallback instances.
|
||||
Native provider instances handle their own initialization in their respective classes.
|
||||
"""
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key,
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
**kwargs,
|
||||
)
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
self.temperature = temperature
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.base_url = base_url
|
||||
self.api_base = api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.prefer_upload = prefer_upload
|
||||
self.additional_params = {
|
||||
k: v for k, v in kwargs.items() if k not in ("is_litellm", "provider")
|
||||
}
|
||||
self.is_anthropic = self._is_anthropic_model(model)
|
||||
self.stream = stream
|
||||
self.interceptor = interceptor
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
# Normalize self.stop to always be a list[str]
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
self.set_callbacks(callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
@model_validator(mode="after")
|
||||
def _init_litellm(self) -> LLM:
|
||||
self.is_litellm = True
|
||||
if LITELLM_AVAILABLE:
|
||||
litellm.drop_params = True
|
||||
self.set_callbacks(self.callbacks or [])
|
||||
self.set_env_callbacks()
|
||||
return self
|
||||
|
||||
@staticmethod
|
||||
def _is_anthropic_model(model: str) -> bool:
|
||||
@@ -753,7 +707,7 @@ class LLM(BaseLLM):
|
||||
"temperature": self.temperature,
|
||||
"top_p": self.top_p,
|
||||
"n": self.n,
|
||||
"stop": self.stop or None,
|
||||
"stop": (self.stop or None) if self.supports_stop_words() else None,
|
||||
"max_tokens": self.max_tokens or self.max_completion_tokens,
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
@@ -1825,9 +1779,11 @@ class LLM(BaseLLM):
|
||||
# whether to summarize the content or abort based on the respect_context_window flag
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
error_str = str(e)
|
||||
unsupported_stop = "'stop'" in error_str and (
|
||||
"Unsupported parameter" in error_str
|
||||
or "does not support parameters" in error_str
|
||||
)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
@@ -1961,9 +1917,11 @@ class LLM(BaseLLM):
|
||||
except LLMContextLengthExceededError:
|
||||
raise
|
||||
except Exception as e:
|
||||
unsupported_stop = "Unsupported parameter" in str(
|
||||
e
|
||||
) and "'stop'" in str(e)
|
||||
error_str = str(e)
|
||||
unsupported_stop = "'stop'" in error_str and (
|
||||
"Unsupported parameter" in error_str
|
||||
or "does not support parameters" in error_str
|
||||
)
|
||||
|
||||
if unsupported_stop:
|
||||
if (
|
||||
@@ -2263,6 +2221,10 @@ class LLM(BaseLLM):
|
||||
Note: This method is only used by the litellm fallback path.
|
||||
Native providers override this method with their own implementation.
|
||||
"""
|
||||
model_lower = self.model.lower() if self.model else ""
|
||||
if "gpt-5" in model_lower:
|
||||
return False
|
||||
|
||||
if not LITELLM_AVAILABLE or get_supported_openai_params is None:
|
||||
# When litellm is not available, assume stop words are supported
|
||||
return True
|
||||
@@ -2434,7 +2396,7 @@ class LLM(BaseLLM):
|
||||
**filtered_params,
|
||||
)
|
||||
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None) -> LLM:
|
||||
def __deepcopy__(self, memo: dict[int, Any] | None = None) -> LLM:
|
||||
"""Create a deep copy of the LLM instance."""
|
||||
import copy
|
||||
|
||||
|
||||
@@ -14,10 +14,18 @@ from datetime import datetime
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Final
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal
|
||||
import uuid
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import (
|
||||
AliasChoices,
|
||||
BaseModel,
|
||||
ConfigDict,
|
||||
Field,
|
||||
PrivateAttr,
|
||||
model_validator,
|
||||
)
|
||||
from typing_extensions import TypedDict
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.llm_events import (
|
||||
@@ -51,6 +59,12 @@ if TYPE_CHECKING:
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class JsonResponseFormat(TypedDict):
|
||||
"""Response format requesting raw JSON output (e.g. ``{"type": "json_object"}``)."""
|
||||
|
||||
type: Literal["json_object"]
|
||||
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
|
||||
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
|
||||
_JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
|
||||
@@ -82,7 +96,7 @@ def get_current_call_id() -> str:
|
||||
return call_id
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
class BaseLLM(BaseModel, ABC):
|
||||
"""Abstract base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
@@ -101,56 +115,100 @@ class BaseLLM(ABC):
|
||||
additional_params: Additional provider-specific parameters.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True, populate_by_name=True)
|
||||
|
||||
model: str
|
||||
temperature: float | None = None
|
||||
api_key: str | None = None
|
||||
base_url: str | None = None
|
||||
provider: str = Field(default="openai")
|
||||
prefer_upload: bool = False
|
||||
is_litellm: bool = False
|
||||
stop: list[str] = Field(
|
||||
default_factory=list,
|
||||
validation_alias=AliasChoices("stop", "stop_sequences"),
|
||||
)
|
||||
additional_params: dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
provider: str | None = None,
|
||||
prefer_upload: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
def __setattr__(self, name: str, value: Any) -> None:
|
||||
if name in ("stop", "stop_sequences"):
|
||||
if value is None:
|
||||
value = []
|
||||
elif isinstance(value, str):
|
||||
value = [value]
|
||||
elif not isinstance(value, list):
|
||||
value = list(value)
|
||||
name = "stop"
|
||||
try:
|
||||
super().__setattr__(name, value)
|
||||
except ValueError:
|
||||
if name in self.model_fields:
|
||||
raise # Re-raise validation errors on declared fields
|
||||
# Fallback for attributes not declared as fields (e.g. mock patching)
|
||||
object.__setattr__(self, name, value)
|
||||
except AttributeError:
|
||||
object.__setattr__(self, name, value)
|
||||
|
||||
Args:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
prefer_upload: Whether to prefer file upload over inline base64.
|
||||
**kwargs: Additional provider-specific parameters.
|
||||
def __delattr__(self, name: str) -> None:
|
||||
try:
|
||||
super().__delattr__(name)
|
||||
except AttributeError:
|
||||
object.__delattr__(self, name)
|
||||
|
||||
@property
|
||||
def stop_sequences(self) -> list[str]:
|
||||
"""Alias for ``stop`` — kept for backward compatibility with provider APIs.
|
||||
|
||||
Writes are handled by ``__setattr__``, which normalizes and redirects
|
||||
``stop_sequences`` assignments to the ``stop`` field.
|
||||
"""
|
||||
if not model:
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
return self.stop
|
||||
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.api_key = api_key
|
||||
self.base_url = base_url
|
||||
self.prefer_upload = prefer_upload
|
||||
# Store additional parameters for provider-specific use
|
||||
self.additional_params = kwargs
|
||||
self._provider = provider or "openai"
|
||||
|
||||
stop = kwargs.pop("stop", None)
|
||||
if stop is None:
|
||||
self.stop: list[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
elif isinstance(stop, list):
|
||||
self.stop = stop
|
||||
else:
|
||||
self.stop = []
|
||||
|
||||
self._token_usage = {
|
||||
_token_usage: dict[str, int] = PrivateAttr(
|
||||
default_factory=lambda: {
|
||||
"total_tokens": 0,
|
||||
"prompt_tokens": 0,
|
||||
"completion_tokens": 0,
|
||||
"successful_requests": 0,
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _validate_init_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if not data.get("model"):
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
# Normalize stop: accept str, list, or None; also accept stop_sequences alias
|
||||
stop_seqs = data.pop("stop_sequences", None)
|
||||
stop = stop_seqs if stop_seqs is not None else data.get("stop")
|
||||
if stop is None:
|
||||
data["stop"] = []
|
||||
elif isinstance(stop, str):
|
||||
data["stop"] = [stop]
|
||||
elif isinstance(stop, list):
|
||||
data["stop"] = stop
|
||||
else:
|
||||
data["stop"] = list(stop)
|
||||
|
||||
# Default provider
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
|
||||
# Collect unknown kwargs into additional_params
|
||||
known_fields = set(cls.model_fields.keys())
|
||||
extras = {k: v for k, v in data.items() if k not in known_fields}
|
||||
for k in extras:
|
||||
data.pop(k)
|
||||
existing = data.get("additional_params") or {}
|
||||
existing.update(extras)
|
||||
data["additional_params"] = existing
|
||||
|
||||
return data
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.
|
||||
@@ -174,16 +232,6 @@ class BaseLLM(ABC):
|
||||
|
||||
return config
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider of the LLM."""
|
||||
return self._provider
|
||||
|
||||
@provider.setter
|
||||
def provider(self, value: str) -> None:
|
||||
"""Set the provider of the LLM."""
|
||||
self._provider = value
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
|
||||
@@ -3,12 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, Final, Literal, TypeGuard, cast
|
||||
from typing import Any, Final, Literal, TypeGuard, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -17,9 +18,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic, transform_schema
|
||||
from anthropic.types import (
|
||||
@@ -150,60 +148,47 @@ class AnthropicCompletion(BaseLLM):
|
||||
offering native tool use, streaming support, and proper message formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int = 4096, # Required for Anthropic
|
||||
top_p: float | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
tool_search: AnthropicToolSearchConfig | bool | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
model: str = "claude-3-5-sonnet-20241022"
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
max_tokens: int = 4096
|
||||
top_p: float | None = None
|
||||
stream: bool = False
|
||||
client_params: dict[str, Any] | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
thinking: AnthropicThinkingConfig | None = None
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
tool_search: AnthropicToolSearchConfig | None = None
|
||||
is_claude_3: bool = False
|
||||
supports_tools: bool = True
|
||||
|
||||
Args:
|
||||
model: Anthropic model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
api_key: Anthropic API key (defaults to ANTHROPIC_API_KEY env var)
|
||||
base_url: Custom base URL for Anthropic API
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-1)
|
||||
max_tokens: Maximum tokens in response (required for Anthropic)
|
||||
top_p: Nucleus sampling parameter
|
||||
stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
|
||||
stream: Enable streaming responses
|
||||
client_params: Additional parameters for the Anthropic client
|
||||
interceptor: HTTP interceptor for modifying requests/responses at transport level.
|
||||
response_format: Pydantic model for structured output. When provided, responses
|
||||
will be validated against this model schema.
|
||||
tool_search: Enable Anthropic's server-side tool search. When True, uses "bm25"
|
||||
variant by default. Pass an AnthropicToolSearchConfig to choose "regex" or
|
||||
"bm25". When enabled, tools are automatically marked with defer_loading=True
|
||||
and a tool search tool is injected into the tools list.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
)
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_previous_thinking_blocks: list[Any] = PrivateAttr(default_factory=list)
|
||||
|
||||
# Client params
|
||||
self.interceptor = interceptor
|
||||
self.client_params = client_params
|
||||
self.base_url = base_url
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_anthropic_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
# Anthropic uses stop_sequences; normalize from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
data["is_claude_3"] = "claude-3" in data.get("model", "").lower()
|
||||
# Normalize tool_search
|
||||
ts = data.get("tool_search")
|
||||
if ts is True:
|
||||
data["tool_search"] = AnthropicToolSearchConfig()
|
||||
elif ts is not None and not isinstance(ts, AnthropicToolSearchConfig):
|
||||
data["tool_search"] = None
|
||||
return data
|
||||
|
||||
self.client = Anthropic(**self._get_client_params())
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AnthropicCompletion:
|
||||
self._client = Anthropic(**self._get_client_params())
|
||||
|
||||
async_client_params = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -211,51 +196,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_params["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncAnthropic(**async_client_params)
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
self.response_format = response_format
|
||||
# Tool search config
|
||||
self.tool_search: AnthropicToolSearchConfig | None
|
||||
if tool_search is True:
|
||||
self.tool_search = AnthropicToolSearchConfig()
|
||||
elif isinstance(tool_search, AnthropicToolSearchConfig):
|
||||
self.tool_search = tool_search
|
||||
else:
|
||||
self.tool_search = None
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Anthropic API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
self._async_client = AsyncAnthropic(**async_client_params)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Anthropic-specific fields."""
|
||||
@@ -751,11 +693,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
)
|
||||
elif isinstance(content, list):
|
||||
formatted_messages.append({"role": "assistant", "content": content})
|
||||
elif self.thinking and self.previous_thinking_blocks:
|
||||
elif self.thinking and self._previous_thinking_blocks:
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
*self.previous_thinking_blocks,
|
||||
*self._previous_thinking_blocks,
|
||||
{"type": "text", "text": content if content else ""},
|
||||
],
|
||||
)
|
||||
@@ -809,7 +751,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -843,11 +785,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = self.client.beta.messages.create(
|
||||
response = self._client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = self.client.messages.create(**params)
|
||||
response = self._client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -928,7 +870,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
self._emit_call_completed_event(
|
||||
@@ -952,7 +894,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -991,9 +933,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self.client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
self._client.beta.messages.stream(**stream_params, extra_body=extra_body)
|
||||
if betas
|
||||
else self.client.messages.stream(**stream_params)
|
||||
else self._client.messages.stream(**stream_params)
|
||||
)
|
||||
with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1072,7 +1014,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
@@ -1269,7 +1211,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self.client.messages.create(**follow_up_params)
|
||||
final_response: Message = self._client.messages.create(**follow_up_params)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
@@ -1288,7 +1230,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
self._previous_thinking_blocks = thinking_blocks
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
@@ -1330,7 +1272,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming async message completion."""
|
||||
uses_file_api = _contains_file_id_reference(params.get("messages", []))
|
||||
@@ -1364,11 +1306,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
try:
|
||||
if betas:
|
||||
params["betas"] = betas
|
||||
response = await self.async_client.beta.messages.create(
|
||||
response = await self._async_client.beta.messages.create(
|
||||
**params, extra_body=extra_body
|
||||
)
|
||||
else:
|
||||
response = await self.async_client.messages.create(**params)
|
||||
response = await self._async_client.messages.create(**params)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
@@ -1461,7 +1403,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
response_model: JsonResponseFormat | type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle async streaming message completion."""
|
||||
betas: list[str] = []
|
||||
@@ -1498,11 +1440,11 @@ class AnthropicCompletion(BaseLLM):
|
||||
current_tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
stream_context = (
|
||||
self.async_client.beta.messages.stream(
|
||||
self._async_client.beta.messages.stream(
|
||||
**stream_params, extra_body=extra_body
|
||||
)
|
||||
if betas
|
||||
else self.async_client.messages.stream(**stream_params)
|
||||
else self._async_client.messages.stream(**stream_params)
|
||||
)
|
||||
async with stream_context as stream:
|
||||
response_id = None
|
||||
@@ -1664,7 +1606,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
try:
|
||||
final_response: Message = await self.async_client.messages.create(
|
||||
final_response: Message = await self._async_client.messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
@@ -1786,8 +1728,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
from crewai_files.uploaders.anthropic import AnthropicFileUploader
|
||||
|
||||
return AnthropicFileUploader(
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -3,11 +3,13 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
from typing import Any, TypedDict
|
||||
from urllib.parse import urlparse
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -16,10 +18,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from azure.ai.inference import (
|
||||
ChatCompletionsClient,
|
||||
@@ -76,109 +74,84 @@ class AzureCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Azure authentication.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
api_key: str | None = None,
|
||||
endpoint: str | None = None,
|
||||
api_version: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
stop: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Azure AI Inference chat completion client.
|
||||
endpoint: str | None = None
|
||||
api_version: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
stream: bool = False
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_openai_model: bool = False
|
||||
is_azure_openai_endpoint: bool = False
|
||||
|
||||
Args:
|
||||
model: Azure deployment name or model name
|
||||
api_key: Azure API key (defaults to AZURE_API_KEY env var)
|
||||
endpoint: Azure endpoint URL (defaults to AZURE_ENDPOINT env var)
|
||||
api_version: Azure API version (defaults to AZURE_API_VERSION env var)
|
||||
timeout: Request timeout in seconds
|
||||
max_retries: Maximum number of retries
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
max_tokens: Maximum tokens in response
|
||||
stop: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
interceptor: HTTP interceptor (not yet supported for Azure).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
Only works with OpenAI models deployed on Azure.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_azure_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Azure AI Inference provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop or [], **kwargs
|
||||
)
|
||||
|
||||
self.api_key = api_key or os.getenv("AZURE_API_KEY")
|
||||
self.endpoint = (
|
||||
endpoint
|
||||
# Resolve env vars
|
||||
data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY")
|
||||
data["endpoint"] = (
|
||||
data.get("endpoint")
|
||||
or os.getenv("AZURE_ENDPOINT")
|
||||
or os.getenv("AZURE_OPENAI_ENDPOINT")
|
||||
or os.getenv("AZURE_API_BASE")
|
||||
)
|
||||
self.api_version = api_version or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
self.timeout = timeout
|
||||
self.max_retries = max_retries
|
||||
data["api_version"] = (
|
||||
data.get("api_version") or os.getenv("AZURE_API_VERSION") or "2024-06-01"
|
||||
)
|
||||
|
||||
if not self.api_key:
|
||||
if not data["api_key"]:
|
||||
raise ValueError(
|
||||
"Azure API key is required. Set AZURE_API_KEY environment variable or pass api_key parameter."
|
||||
)
|
||||
if not self.endpoint:
|
||||
if not data["endpoint"]:
|
||||
raise ValueError(
|
||||
"Azure endpoint is required. Set AZURE_ENDPOINT environment variable or pass endpoint parameter."
|
||||
)
|
||||
|
||||
# Validate and potentially fix Azure OpenAI endpoint URL
|
||||
self.endpoint = self._validate_and_fix_endpoint(self.endpoint, model)
|
||||
model = data.get("model", "")
|
||||
data["endpoint"] = AzureCompletion._validate_and_fix_endpoint(
|
||||
data["endpoint"], model
|
||||
)
|
||||
data["is_openai_model"] = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
parsed = urlparse(data["endpoint"])
|
||||
hostname = parsed.hostname or ""
|
||||
data["is_azure_openai_endpoint"] = (
|
||||
hostname == "openai.azure.com" or hostname.endswith(".openai.azure.com")
|
||||
) and "/openai/deployments/" in data["endpoint"]
|
||||
return data
|
||||
|
||||
# Build client kwargs
|
||||
client_kwargs = {
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> AzureCompletion:
|
||||
if not self.api_key:
|
||||
raise ValueError("Azure API key is required.")
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"endpoint": self.endpoint,
|
||||
"credential": AzureKeyCredential(self.api_key),
|
||||
}
|
||||
|
||||
# Add api_version if specified (primarily for Azure OpenAI endpoints)
|
||||
if self.api_version:
|
||||
client_kwargs["api_version"] = self.api_version
|
||||
|
||||
self.client = ChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.async_client = AsyncChatCompletionsClient(**client_kwargs) # type: ignore[arg-type]
|
||||
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
|
||||
self.is_openai_model = any(
|
||||
prefix in model.lower() for prefix in ["gpt-", "o1-", "text-"]
|
||||
)
|
||||
|
||||
self.is_azure_openai_endpoint = (
|
||||
"openai.azure.com" in self.endpoint
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
self._client = ChatCompletionsClient(**client_kwargs)
|
||||
self._async_client = AsyncChatCompletionsClient(**client_kwargs)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Azure-specific fields."""
|
||||
@@ -215,7 +188,11 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Validated and potentially corrected endpoint URL
|
||||
"""
|
||||
if "openai.azure.com" in endpoint and "/openai/deployments/" not in endpoint:
|
||||
ep_host = urlparse(endpoint).hostname or ""
|
||||
is_azure_openai = ep_host == "openai.azure.com" or ep_host.endswith(
|
||||
".openai.azure.com"
|
||||
)
|
||||
if is_azure_openai and "/openai/deployments/" not in endpoint:
|
||||
endpoint = endpoint.rstrip("/")
|
||||
|
||||
if not endpoint.endswith("/openai/deployments"):
|
||||
@@ -731,7 +708,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
response: ChatCompletions = self._client.complete(**params)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -926,7 +903,7 @@ class AzureCompletion(BaseLLM):
|
||||
tool_calls: dict[int, dict[str, Any]] = {}
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
for update in self.client.complete(**params): # type: ignore[arg-type]
|
||||
for update in self._client.complete(**params):
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.usage:
|
||||
usage = update.usage
|
||||
@@ -967,7 +944,7 @@ class AzureCompletion(BaseLLM):
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
response: ChatCompletions = await self._async_client.complete(**params)
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
@@ -993,8 +970,8 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
|
||||
async for update in stream: # type: ignore[union-attr]
|
||||
stream = await self._async_client.complete(**params)
|
||||
async for update in stream:
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if hasattr(update, "usage") and update.usage:
|
||||
usage = update.usage
|
||||
@@ -1110,8 +1087,8 @@ class AzureCompletion(BaseLLM):
|
||||
This ensures proper cleanup of the underlying aiohttp session
|
||||
to avoid unclosed connector warnings.
|
||||
"""
|
||||
if hasattr(self.async_client, "close"):
|
||||
await self.async_client.close()
|
||||
if hasattr(self._async_client, "close"):
|
||||
await self._async_client.close()
|
||||
|
||||
async def __aenter__(self) -> Self:
|
||||
"""Async context manager entry."""
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, TypedDict, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
from typing_extensions import Required
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
|
||||
ToolTypeDef,
|
||||
)
|
||||
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
@@ -228,129 +228,97 @@ class BedrockCompletion(BaseLLM):
|
||||
- Model-specific conversation format handling (e.g., Cohere requirements)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0",
|
||||
aws_access_key_id: str | None = None,
|
||||
aws_secret_access_key: str | None = None,
|
||||
aws_session_token: str | None = None,
|
||||
region_name: str | None = None,
|
||||
temperature: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
stop_sequences: Sequence[str] | None = None,
|
||||
stream: bool = False,
|
||||
guardrail_config: dict[str, Any] | None = None,
|
||||
additional_model_request_fields: dict[str, Any] | None = None,
|
||||
additional_model_response_field_paths: list[str] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize AWS Bedrock completion client.
|
||||
model: str = "anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
aws_access_key_id: str | None = None
|
||||
aws_secret_access_key: str | None = None
|
||||
aws_session_token: str | None = None
|
||||
region_name: str | None = None
|
||||
max_tokens: int | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
stream: bool = False
|
||||
guardrail_config: dict[str, Any] | None = None
|
||||
additional_model_request_fields: dict[str, Any] | None = None
|
||||
additional_model_response_field_paths: list[str] | None = None
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
response_format: type[BaseModel] | None = None
|
||||
is_claude_model: bool = False
|
||||
supports_tools: bool = True
|
||||
supports_streaming: bool = True
|
||||
model_id: str = ""
|
||||
|
||||
Args:
|
||||
model: The Bedrock model ID to use
|
||||
aws_access_key_id: AWS access key (defaults to environment variable)
|
||||
aws_secret_access_key: AWS secret key (defaults to environment variable)
|
||||
aws_session_token: AWS session token for temporary credentials
|
||||
region_name: AWS region name
|
||||
temperature: Sampling temperature for response generation
|
||||
max_tokens: Maximum tokens to generate
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter (Claude models only)
|
||||
stop_sequences: List of sequences that stop generation
|
||||
stream: Whether to use streaming responses
|
||||
guardrail_config: Guardrail configuration for content filtering
|
||||
additional_model_request_fields: Model-specific request parameters
|
||||
additional_model_response_field_paths: Custom response field paths
|
||||
interceptor: HTTP interceptor (not yet supported for Bedrock).
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_exit_stack: Any = PrivateAttr(default=None)
|
||||
_async_client_initialized: bool = PrivateAttr(default=False)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_bedrock_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for AWS Bedrock provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Extract provider from kwargs to avoid duplicate argument
|
||||
kwargs.pop("provider", None)
|
||||
# Force provider to bedrock
|
||||
data.pop("provider", None)
|
||||
data["provider"] = "bedrock"
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
stop=stop_sequences or [],
|
||||
provider="bedrock",
|
||||
**kwargs,
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
elif isinstance(seqs, Sequence) and not isinstance(seqs, list):
|
||||
seqs = list(seqs)
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv(
|
||||
"AWS_ACCESS_KEY_ID"
|
||||
)
|
||||
|
||||
# Configure client with timeouts and retries following AWS best practices
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={
|
||||
"max_attempts": 3,
|
||||
"mode": "adaptive",
|
||||
},
|
||||
tcp_keepalive=True,
|
||||
data["aws_secret_access_key"] = data.get("aws_secret_access_key") or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
|
||||
self.region_name = (
|
||||
region_name
|
||||
data["aws_session_token"] = data.get("aws_session_token") or os.getenv(
|
||||
"AWS_SESSION_TOKEN"
|
||||
)
|
||||
data["region_name"] = (
|
||||
data.get("region_name")
|
||||
or os.getenv("AWS_DEFAULT_REGION")
|
||||
or os.getenv("AWS_REGION_NAME")
|
||||
or "us-east-1"
|
||||
)
|
||||
|
||||
self.aws_access_key_id = aws_access_key_id or os.getenv("AWS_ACCESS_KEY_ID")
|
||||
self.aws_secret_access_key = aws_secret_access_key or os.getenv(
|
||||
"AWS_SECRET_ACCESS_KEY"
|
||||
)
|
||||
self.aws_session_token = aws_session_token or os.getenv("AWS_SESSION_TOKEN")
|
||||
model = data.get("model", "anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
data["is_claude_model"] = "claude" in model.lower()
|
||||
data["model_id"] = model
|
||||
return data
|
||||
|
||||
# Initialize Bedrock client with proper configuration
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> BedrockCompletion:
|
||||
config = Config(
|
||||
read_timeout=300,
|
||||
retries={"max_attempts": 3, "mode": "adaptive"},
|
||||
tcp_keepalive=True,
|
||||
)
|
||||
session = Session(
|
||||
aws_access_key_id=self.aws_access_key_id,
|
||||
aws_secret_access_key=self.aws_secret_access_key,
|
||||
aws_session_token=self.aws_session_token,
|
||||
region_name=self.region_name,
|
||||
)
|
||||
|
||||
self.client = session.client("bedrock-runtime", config=config)
|
||||
|
||||
self._client = session.client("bedrock-runtime", config=config)
|
||||
self._async_exit_stack = AsyncExitStack() if AIOBOTOCORE_AVAILABLE else None
|
||||
self._async_client_initialized = False
|
||||
|
||||
# Store completion parameters
|
||||
self.max_tokens = max_tokens
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences
|
||||
self.response_format = response_format
|
||||
|
||||
# Store advanced features (optional)
|
||||
self.guardrail_config = guardrail_config
|
||||
self.additional_model_request_fields = additional_model_request_fields
|
||||
self.additional_model_response_field_paths = (
|
||||
additional_model_response_field_paths
|
||||
)
|
||||
|
||||
# Model-specific settings
|
||||
self.is_claude_model = "claude" in model.lower()
|
||||
self.supports_tools = True # Converse API supports tools for most models
|
||||
self.supports_streaming = True
|
||||
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Bedrock-specific fields."""
|
||||
config = super().to_config_dict()
|
||||
# NOTE: AWS credentials (access_key, secret_key, session_token) are
|
||||
# intentionally excluded — they must come from env on resume.
|
||||
if self.region_name and self.region_name != "us-east-1":
|
||||
config["region_name"] = self.region_name
|
||||
if self.max_tokens is not None:
|
||||
@@ -363,30 +331,6 @@ class BedrockCompletion(BaseLLM):
|
||||
config["guardrail_config"] = self.guardrail_config
|
||||
return config
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return [] if self.stop_sequences is None else list(self.stop_sequences)
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: Sequence[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Bedrock API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a Sequence, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, Sequence):
|
||||
self.stop_sequences = list(value)
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
@@ -710,7 +654,7 @@ class BedrockCompletion(BaseLLM):
|
||||
raise ValueError(f"Invalid message format at index {i}")
|
||||
|
||||
# Call Bedrock Converse API with proper error handling
|
||||
response = self.client.converse(
|
||||
response = self._client.converse(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
@@ -994,13 +938,13 @@ class BedrockCompletion(BaseLLM):
|
||||
accumulated_tool_input = ""
|
||||
|
||||
try:
|
||||
response = self.client.converse_stream(
|
||||
response = self._client.converse_stream(
|
||||
modelId=self.model_id,
|
||||
messages=cast(
|
||||
"Sequence[MessageTypeDef | MessageOutputTypeDef]",
|
||||
cast(object, messages),
|
||||
),
|
||||
**body, # type: ignore[arg-type]
|
||||
**body,
|
||||
)
|
||||
|
||||
stream = response.get("stream")
|
||||
|
||||
@@ -5,12 +5,13 @@ import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
@@ -19,10 +20,6 @@ from crewai.utilities.pydantic_schema_utils import generate_model_description
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
@@ -44,137 +41,84 @@ class GeminiCompletion(BaseLLM):
|
||||
offering native function calling, streaming support, and proper Gemini formatting.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-2.0-flash-001",
|
||||
api_key: str | None = None,
|
||||
project: str | None = None,
|
||||
location: str | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
top_k: int | None = None,
|
||||
max_output_tokens: int | None = None,
|
||||
stop_sequences: list[str] | None = None,
|
||||
stream: bool = False,
|
||||
safety_settings: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None,
|
||||
use_vertexai: bool | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
thinking_config: types.ThinkingConfig | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Google Gemini chat completion client.
|
||||
model: str = "gemini-2.0-flash-001"
|
||||
project: str | None = None
|
||||
location: str | None = None
|
||||
top_p: float | None = None
|
||||
top_k: int | None = None
|
||||
max_output_tokens: int | None = None
|
||||
stream: bool = False
|
||||
safety_settings: dict[str, Any] = Field(default_factory=dict)
|
||||
client_params: dict[str, Any] = Field(default_factory=dict)
|
||||
interceptor: BaseInterceptor[Any, Any] | None = None
|
||||
use_vertexai: bool = False
|
||||
response_format: type[BaseModel] | None = None
|
||||
thinking_config: Any = None
|
||||
tools: list[dict[str, Any]] | None = None
|
||||
supports_tools: bool = False
|
||||
is_gemini_2_0: bool = False
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-2.0-flash-001', 'gemini-1.5-pro')
|
||||
api_key: Google API key for Gemini API authentication.
|
||||
Defaults to GOOGLE_API_KEY or GEMINI_API_KEY env var.
|
||||
NOTE: Cannot be used with Vertex AI (project parameter). Use Gemini API instead.
|
||||
project: Google Cloud project ID for Vertex AI with ADC authentication.
|
||||
Requires Application Default Credentials (gcloud auth application-default login).
|
||||
NOTE: Vertex AI does NOT support API keys, only OAuth2/ADC.
|
||||
If both api_key and project are set, api_key takes precedence.
|
||||
location: Google Cloud location (for Vertex AI with ADC, defaults to 'us-central1')
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
top_k: Top-k sampling parameter
|
||||
max_output_tokens: Maximum tokens in response
|
||||
stop_sequences: Stop sequences
|
||||
stream: Enable streaming responses
|
||||
safety_settings: Safety filter settings
|
||||
client_params: Additional parameters to pass to the Google Gen AI Client constructor.
|
||||
Supports parameters like http_options, credentials, debug_config, etc.
|
||||
interceptor: HTTP interceptor (not yet supported for Gemini).
|
||||
use_vertexai: Whether to use Vertex AI instead of Gemini API.
|
||||
- True: Use Vertex AI (with ADC or Express mode with API key)
|
||||
- False: Use Gemini API (explicitly override env var)
|
||||
- None (default): Check GOOGLE_GENAI_USE_VERTEXAI env var
|
||||
When using Vertex AI with API key (Express mode), http_options with
|
||||
api_version="v1" is automatically configured.
|
||||
response_format: Pydantic model for structured output. Used as default when
|
||||
response_model is not passed to call()/acall() methods.
|
||||
thinking_config: ThinkingConfig for thinking models (gemini-2.5+, gemini-3+).
|
||||
Controls thought output via include_thoughts, thinking_budget,
|
||||
and thinking_level. When None, thinking models automatically
|
||||
get include_thoughts=True so thought content is surfaced.
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
if interceptor is not None:
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_gemini_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
if data.get("interceptor") is not None:
|
||||
raise NotImplementedError(
|
||||
"HTTP interceptors are not yet supported for Google Gemini provider. "
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = (
|
||||
data.get("api_key")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
data["project"] = data.get("project") or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
data["location"] = (
|
||||
data.get("location") or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
)
|
||||
|
||||
# Store client params for later use
|
||||
self.client_params = client_params or {}
|
||||
|
||||
# Get API configuration with environment variable fallbacks
|
||||
self.api_key = (
|
||||
api_key or os.getenv("GOOGLE_API_KEY") or os.getenv("GEMINI_API_KEY")
|
||||
)
|
||||
self.project = project or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location or os.getenv("GOOGLE_CLOUD_LOCATION") or "us-central1"
|
||||
|
||||
if use_vertexai is None:
|
||||
use_vertexai = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
|
||||
self.client = self._initialize_client(use_vertexai)
|
||||
|
||||
# Store completion parameters
|
||||
self.top_p = top_p
|
||||
self.top_k = top_k
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.tools: list[dict[str, Any]] | None = None
|
||||
self.response_format = response_format
|
||||
use_vx = data.get("use_vertexai")
|
||||
if use_vx is None:
|
||||
use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
data["use_vertexai"] = use_vx
|
||||
|
||||
# Model-specific settings
|
||||
model = data.get("model", "gemini-2.0-flash-001")
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
self.supports_tools = bool(
|
||||
data["supports_tools"] = bool(
|
||||
version_match and float(version_match.group(1)) >= 1.5
|
||||
)
|
||||
self.is_gemini_2_0 = bool(
|
||||
data["is_gemini_2_0"] = bool(
|
||||
version_match and float(version_match.group(1)) >= 2.0
|
||||
)
|
||||
|
||||
self.thinking_config = thinking_config
|
||||
# Auto-enable thinking for gemini-2.5+
|
||||
if (
|
||||
self.thinking_config is None
|
||||
data.get("thinking_config") is None
|
||||
and version_match
|
||||
and float(version_match.group(1)) >= 2.5
|
||||
):
|
||||
self.thinking_config = types.ThinkingConfig(include_thoughts=True)
|
||||
data["thinking_config"] = types.ThinkingConfig(include_thoughts=True)
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
return self.stop_sequences
|
||||
return data
|
||||
|
||||
@stop.setter
|
||||
def stop(self, value: list[str] | str | None) -> None:
|
||||
"""Set stop sequences.
|
||||
|
||||
Synchronizes stop_sequences to ensure values set by CrewAgentExecutor
|
||||
are properly sent to the Gemini API.
|
||||
|
||||
Args:
|
||||
value: Stop sequences as a list, single string, or None
|
||||
"""
|
||||
if value is None:
|
||||
self.stop_sequences = []
|
||||
elif isinstance(value, str):
|
||||
self.stop_sequences = [value]
|
||||
elif isinstance(value, list):
|
||||
self.stop_sequences = value
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
@model_validator(mode="after")
|
||||
def _init_client(self) -> GeminiCompletion:
|
||||
self._client = self._initialize_client(self.use_vertexai)
|
||||
return self
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with Gemini/Vertex-specific fields."""
|
||||
@@ -283,8 +227,8 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
if (
|
||||
hasattr(self, "client")
|
||||
and hasattr(self.client, "vertexai")
|
||||
and self.client.vertexai
|
||||
and hasattr(self._client, "vertexai")
|
||||
and self._client.vertexai
|
||||
):
|
||||
# Vertex AI configuration
|
||||
params.update(
|
||||
@@ -1152,7 +1096,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self.client.models.generate_content(
|
||||
response = self._client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1192,7 +1136,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
for chunk in self._client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1230,7 +1174,7 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = await self.client.aio.models.generate_content(
|
||||
response = await self._client.aio.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1270,7 +1214,7 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
stream = await self.client.aio.models.generate_content_stream(
|
||||
stream = await self._client.aio.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
@@ -1474,6 +1418,6 @@ class GeminiCompletion(BaseLLM):
|
||||
try:
|
||||
from crewai_files.uploaders.gemini import GeminiFileUploader
|
||||
|
||||
return GeminiFileUploader(client=self.client)
|
||||
return GeminiFileUploader(client=self._client)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -14,10 +14,11 @@ from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
from openai.types.responses import Response
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, PrivateAttr, model_validator
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
from crewai.llms.base_llm import BaseLLM, llm_call_context
|
||||
from crewai.llms.base_llm import BaseLLM, JsonResponseFormat, llm_call_context
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.hooks.transport import AsyncHTTPTransport, HTTPTransport
|
||||
from crewai.utilities.agent_utils import is_context_length_exceeded
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
@@ -29,7 +30,6 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
@@ -183,77 +183,69 @@ class OpenAICompletion(BaseLLM):
|
||||
"computer_use": "computer_use_preview",
|
||||
}
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4o",
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
organization: str | None = None,
|
||||
project: str | None = None,
|
||||
timeout: float | None = None,
|
||||
max_retries: int = 2,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
default_query: dict[str, Any] | None = None,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
temperature: float | None = None,
|
||||
top_p: float | None = None,
|
||||
frequency_penalty: float | None = None,
|
||||
presence_penalty: float | None = None,
|
||||
max_tokens: int | None = None,
|
||||
max_completion_tokens: int | None = None,
|
||||
seed: int | None = None,
|
||||
stream: bool = False,
|
||||
response_format: dict[str, Any] | type[BaseModel] | None = None,
|
||||
logprobs: bool | None = None,
|
||||
top_logprobs: int | None = None,
|
||||
reasoning_effort: str | None = None,
|
||||
provider: str | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
api: Literal["completions", "responses"] = "completions",
|
||||
instructions: str | None = None,
|
||||
store: bool | None = None,
|
||||
previous_response_id: str | None = None,
|
||||
include: list[str] | None = None,
|
||||
builtin_tools: list[str] | None = None,
|
||||
parse_tool_outputs: bool = False,
|
||||
auto_chain: bool = False,
|
||||
auto_chain_reasoning: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI completion client."""
|
||||
model: str = "gpt-4o"
|
||||
organization: str | None = None
|
||||
project: str | None = None
|
||||
timeout: float | None = None
|
||||
max_retries: int = 2
|
||||
default_headers: dict[str, str] | None = None
|
||||
default_query: dict[str, Any] | None = None
|
||||
client_params: dict[str, Any] | None = None
|
||||
top_p: float | None = None
|
||||
frequency_penalty: float | None = None
|
||||
presence_penalty: float | None = None
|
||||
max_tokens: int | None = None
|
||||
max_completion_tokens: int | None = None
|
||||
seed: int | None = None
|
||||
stream: bool = False
|
||||
response_format: JsonResponseFormat | type[BaseModel] | None = None
|
||||
logprobs: bool | None = None
|
||||
top_logprobs: int | None = None
|
||||
reasoning_effort: str | None = None
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None
|
||||
api: Literal["completions", "responses"] = "completions"
|
||||
instructions: str | None = None
|
||||
store: bool | None = None
|
||||
previous_response_id: str | None = None
|
||||
include: list[str] | None = None
|
||||
builtin_tools: list[str] | None = None
|
||||
parse_tool_outputs: bool = False
|
||||
auto_chain: bool = False
|
||||
auto_chain_reasoning: bool = False
|
||||
api_base: str | None = None
|
||||
is_o1_model: bool = False
|
||||
is_gpt4_model: bool = False
|
||||
|
||||
if provider is None:
|
||||
provider = kwargs.pop("provider", "openai")
|
||||
_client: Any = PrivateAttr(default=None)
|
||||
_async_client: Any = PrivateAttr(default=None)
|
||||
_last_response_id: str | None = PrivateAttr(default=None)
|
||||
_last_reasoning_items: list[Any] | None = PrivateAttr(default=None)
|
||||
|
||||
self.interceptor = interceptor
|
||||
# Client configuration attributes
|
||||
self.organization = organization
|
||||
self.project = project
|
||||
self.max_retries = max_retries
|
||||
self.default_headers = default_headers
|
||||
self.default_query = default_query
|
||||
self.client_params = client_params
|
||||
self.timeout = timeout
|
||||
self.base_url = base_url
|
||||
self.api_base = kwargs.pop("api_base", None)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
temperature=temperature,
|
||||
api_key=api_key or os.getenv("OPENAI_API_KEY"),
|
||||
base_url=base_url,
|
||||
timeout=timeout,
|
||||
provider=provider,
|
||||
**kwargs,
|
||||
)
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _normalize_openai_fields(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY")
|
||||
# Extract api_base from kwargs if present
|
||||
if "api_base" not in data:
|
||||
data["api_base"] = None
|
||||
model = data.get("model", "gpt-4o")
|
||||
data["is_o1_model"] = "o1" in model.lower()
|
||||
data["is_gpt4_model"] = "gpt-4" in model.lower()
|
||||
return data
|
||||
|
||||
@model_validator(mode="after")
|
||||
def _init_clients(self) -> OpenAICompletion:
|
||||
client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
transport = HTTPTransport(interceptor=self.interceptor)
|
||||
http_client = httpx.Client(transport=transport)
|
||||
client_config["http_client"] = http_client
|
||||
|
||||
self.client = OpenAI(**client_config)
|
||||
self._client = OpenAI(**client_config)
|
||||
|
||||
async_client_config = self._get_client_params()
|
||||
if self.interceptor:
|
||||
@@ -261,35 +253,8 @@ class OpenAICompletion(BaseLLM):
|
||||
async_http_client = httpx.AsyncClient(transport=async_transport)
|
||||
async_client_config["http_client"] = async_http_client
|
||||
|
||||
self.async_client = AsyncOpenAI(**async_client_config)
|
||||
|
||||
# Completion parameters
|
||||
self.top_p = top_p
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.presence_penalty = presence_penalty
|
||||
self.max_tokens = max_tokens
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.seed = seed
|
||||
self.stream = stream
|
||||
self.response_format = response_format
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.is_o1_model = "o1" in model.lower()
|
||||
self.is_gpt4_model = "gpt-4" in model.lower()
|
||||
|
||||
# API selection and Responses API parameters
|
||||
self.api = api
|
||||
self.instructions = instructions
|
||||
self.store = store
|
||||
self.previous_response_id = previous_response_id
|
||||
self.include = include
|
||||
self.builtin_tools = builtin_tools
|
||||
self.parse_tool_outputs = parse_tool_outputs
|
||||
self.auto_chain = auto_chain
|
||||
self.auto_chain_reasoning = auto_chain_reasoning
|
||||
self._last_response_id: str | None = None
|
||||
self._last_reasoning_items: list[Any] | None = None
|
||||
self._async_client = AsyncOpenAI(**async_client_config)
|
||||
return self
|
||||
|
||||
@property
|
||||
def last_response_id(self) -> str | None:
|
||||
@@ -818,7 +783,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = self.client.responses.create(**params)
|
||||
response: Response = self._client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -950,7 +915,7 @@ class OpenAICompletion(BaseLLM):
|
||||
) -> str | ResponsesAPIResult | Any:
|
||||
"""Handle async non-streaming Responses API call."""
|
||||
try:
|
||||
response: Response = await self.async_client.responses.create(**params)
|
||||
response: Response = await self._async_client.responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
@@ -1081,7 +1046,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = self.client.responses.create(**params)
|
||||
stream = self._client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
for event in stream:
|
||||
@@ -1205,7 +1170,7 @@ class OpenAICompletion(BaseLLM):
|
||||
function_calls: list[dict[str, Any]] = []
|
||||
final_response: Response | None = None
|
||||
|
||||
stream = await self.async_client.responses.create(**params)
|
||||
stream = await self._async_client.responses.create(**params)
|
||||
response_id_stream = None
|
||||
|
||||
async for event in stream:
|
||||
@@ -1595,7 +1560,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = self.client.beta.chat.completions.parse(
|
||||
parsed_response = self._client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1618,7 +1583,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = self.client.chat.completions.create(**params)
|
||||
response: ChatCompletion = self._client.chat.completions.create(**params)
|
||||
|
||||
usage = self._extract_openai_token_usage(response)
|
||||
|
||||
@@ -1837,7 +1802,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self.client.beta.chat.completions.stream(
|
||||
with self._client.beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
@@ -1873,7 +1838,7 @@ class OpenAICompletion(BaseLLM):
|
||||
return ""
|
||||
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
self._client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -1970,7 +1935,7 @@ class OpenAICompletion(BaseLLM):
|
||||
parse_params = {
|
||||
k: v for k, v in params.items() if k != "response_format"
|
||||
}
|
||||
parsed_response = await self.async_client.beta.chat.completions.parse(
|
||||
parsed_response = await self._async_client.beta.chat.completions.parse(
|
||||
**parse_params,
|
||||
response_format=response_model,
|
||||
)
|
||||
@@ -1993,7 +1958,7 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return parsed_object
|
||||
|
||||
response: ChatCompletion = await self.async_client.chat.completions.create(
|
||||
response: ChatCompletion = await self._async_client.chat.completions.create(
|
||||
**params
|
||||
)
|
||||
|
||||
@@ -2111,7 +2076,7 @@ class OpenAICompletion(BaseLLM):
|
||||
if response_model:
|
||||
completion_stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
@@ -2164,7 +2129,7 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
stream: AsyncIterator[
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
] = await self._async_client.chat.completions.create(**params)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
@@ -2245,6 +2210,9 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the model supports stop words."""
|
||||
model_lower = self.model.lower() if self.model else ""
|
||||
if "gpt-5" in model_lower:
|
||||
return False
|
||||
return not self.is_o1_model
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
@@ -2353,8 +2321,8 @@ class OpenAICompletion(BaseLLM):
|
||||
from crewai_files.uploaders.openai import OpenAIFileUploader
|
||||
|
||||
return OpenAIFileUploader(
|
||||
client=self.client,
|
||||
async_client=self.async_client,
|
||||
client=self._client,
|
||||
async_client=self._async_client,
|
||||
)
|
||||
except ImportError:
|
||||
return None
|
||||
|
||||
@@ -16,6 +16,8 @@ from dataclasses import dataclass, field
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from pydantic import model_validator
|
||||
|
||||
from crewai.llms.providers.openai.completion import OpenAICompletion
|
||||
|
||||
|
||||
@@ -140,31 +142,13 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
)
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
provider: str,
|
||||
api_key: str | None = None,
|
||||
base_url: str | None = None,
|
||||
default_headers: dict[str, str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize OpenAI-compatible completion client.
|
||||
@model_validator(mode="before")
|
||||
@classmethod
|
||||
def _resolve_provider_config(cls, data: Any) -> Any:
|
||||
if not isinstance(data, dict):
|
||||
return data
|
||||
|
||||
Args:
|
||||
model: The model identifier.
|
||||
provider: The provider name (must be in OPENAI_COMPATIBLE_PROVIDERS).
|
||||
api_key: Optional API key override. If not provided, uses the
|
||||
provider's configured environment variable.
|
||||
base_url: Optional base URL override. If not provided, uses the
|
||||
provider's configured default or environment variable.
|
||||
default_headers: Optional headers to merge with provider defaults.
|
||||
**kwargs: Additional arguments passed to OpenAICompletion.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provider is not supported or required API key
|
||||
is missing.
|
||||
"""
|
||||
provider = data.get("provider", "")
|
||||
config = OPENAI_COMPATIBLE_PROVIDERS.get(provider)
|
||||
if config is None:
|
||||
supported = ", ".join(sorted(OPENAI_COMPATIBLE_PROVIDERS.keys()))
|
||||
@@ -173,21 +157,15 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
f"Supported providers: {supported}"
|
||||
)
|
||||
|
||||
resolved_api_key = self._resolve_api_key(api_key, config, provider)
|
||||
resolved_base_url = self._resolve_base_url(base_url, config, provider)
|
||||
resolved_headers = self._resolve_headers(default_headers, config)
|
||||
|
||||
super().__init__(
|
||||
model=model,
|
||||
provider=provider,
|
||||
api_key=resolved_api_key,
|
||||
base_url=resolved_base_url,
|
||||
default_headers=resolved_headers,
|
||||
**kwargs,
|
||||
data["api_key"] = cls._resolve_api_key(data.get("api_key"), config, provider)
|
||||
data["base_url"] = cls._resolve_base_url(data.get("base_url"), config, provider)
|
||||
data["default_headers"] = cls._resolve_headers(
|
||||
data.get("default_headers"), config
|
||||
)
|
||||
return data
|
||||
|
||||
@staticmethod
|
||||
def _resolve_api_key(
|
||||
self,
|
||||
api_key: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -220,8 +198,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return config.default_api_key
|
||||
|
||||
@staticmethod
|
||||
def _resolve_base_url(
|
||||
self,
|
||||
base_url: str | None,
|
||||
config: ProviderConfig,
|
||||
provider: str,
|
||||
@@ -249,8 +227,8 @@ class OpenAICompatibleCompletion(OpenAICompletion):
|
||||
|
||||
return resolved
|
||||
|
||||
@staticmethod
|
||||
def _resolve_headers(
|
||||
self,
|
||||
headers: dict[str, str] | None,
|
||||
config: ProviderConfig,
|
||||
) -> dict[str, str] | None:
|
||||
|
||||
@@ -1 +0,0 @@
|
||||
"""Third-party LLM implementations for crewAI."""
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
from collections import defaultdict
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from pydantic import BaseModel, Field, InstanceOf
|
||||
from pydantic import BaseModel, Field
|
||||
from rich.box import HEAVY_EDGE
|
||||
from rich.console import Console
|
||||
from rich.table import Table
|
||||
@@ -39,9 +39,9 @@ class CrewEvaluator:
|
||||
def __init__(
|
||||
self,
|
||||
crew: Crew,
|
||||
eval_llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
eval_llm: BaseLLM | str | None = None,
|
||||
openai_model_name: str | None = None,
|
||||
llm: InstanceOf[BaseLLM] | str | None = None,
|
||||
llm: BaseLLM | str | None = None,
|
||||
) -> None:
|
||||
self.crew = crew
|
||||
self.llm = eval_llm
|
||||
|
||||
@@ -1692,9 +1692,27 @@ def test_agent_with_knowledge_sources_works_with_copy():
|
||||
) as mock_knowledge_storage:
|
||||
from crewai.knowledge.storage.base_knowledge_storage import BaseKnowledgeStorage
|
||||
|
||||
mock_knowledge_storage_instance = mock_knowledge_storage.return_value
|
||||
mock_knowledge_storage_instance.__class__ = BaseKnowledgeStorage
|
||||
agent.knowledge_storage = mock_knowledge_storage_instance
|
||||
class _StubStorage(BaseKnowledgeStorage):
|
||||
def search(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
async def asearch(self, query, limit=5, metadata_filter=None, score_threshold=0.6):
|
||||
return []
|
||||
|
||||
def save(self, documents):
|
||||
pass
|
||||
|
||||
async def asave(self, documents):
|
||||
pass
|
||||
|
||||
def reset(self):
|
||||
pass
|
||||
|
||||
async def areset(self):
|
||||
pass
|
||||
|
||||
mock_knowledge_storage.return_value = _StubStorage()
|
||||
agent.knowledge_storage = _StubStorage()
|
||||
|
||||
agent_copy = agent.copy()
|
||||
|
||||
|
||||
@@ -879,30 +879,6 @@ class TestNativeToolExecution:
|
||||
assert len(tool_messages) == 1
|
||||
assert tool_messages[0]["tool_call_id"] == "call_1"
|
||||
|
||||
def test_check_native_todo_completion_requires_current_todo(
|
||||
self, mock_dependencies
|
||||
):
|
||||
from crewai.utilities.planning_types import TodoList
|
||||
|
||||
executor = AgentExecutor(**mock_dependencies)
|
||||
|
||||
# No current todo → not satisfied
|
||||
executor.state.todos = TodoList(items=[])
|
||||
assert executor.check_native_todo_completion() == "todo_not_satisfied"
|
||||
|
||||
# With a current todo that has tool_to_use → satisfied
|
||||
running = TodoItem(
|
||||
step_number=1,
|
||||
description="Use the expected tool",
|
||||
tool_to_use="expected_tool",
|
||||
status="running",
|
||||
)
|
||||
executor.state.todos = TodoList(items=[running])
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
# With a current todo without tool_to_use → still satisfied
|
||||
running.tool_to_use = None
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
|
||||
class TestPlannerObserver:
|
||||
|
||||
@@ -1,7 +1,11 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -68,13 +72,9 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Tokyo?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -87,7 +87,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6158'
|
||||
- '6065'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
@@ -109,26 +109,113 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/responses
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXQCgT3p3ViImkiqDiZGqLREQtp\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747248,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_9ZqMavn3J1fBnQEaqpYol0Bd\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Tokyo\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
string: "{\n \"id\": \"resp_0d68149bcc0d14810069caf464a4b48197bd9f098abb2f6303\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908516,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908517,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0d68149bcc0d14810069caf46568088197a33be67f16a1fa09\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Tokyo\\\"}\",\n \"call_id\": \"call_74rwmYse0DE4JFaFGyAFx9bu\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 0\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -137,7 +224,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 10 Feb 2026 18:14:08 GMT
|
||||
- Mon, 30 Mar 2026 22:08:37 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -146,8 +233,6 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -155,15 +240,13 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '484'
|
||||
- '1085'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
@@ -182,8 +265,12 @@ interactions:
|
||||
code: 200
|
||||
message: OK
|
||||
- request:
|
||||
body: '{"messages":[{"role":"system","content":"You are a helpful assistant that
|
||||
uses tools. This is padding text to ensure the prompt is large enough for caching.
|
||||
body: '{"input":[{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","instructions":"You
|
||||
are a helpful assistant that uses tools. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
@@ -250,13 +337,9 @@ interactions:
|
||||
for caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. "},{"role":"user","content":"What is the weather in Paris?"}],"model":"gpt-4.1","tool_choice":"auto","tools":[{"type":"function","function":{"name":"get_weather","description":"Get
|
||||
the current weather for a location","strict":true,"parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"],"additionalProperties":false}}}]}'
|
||||
text to ensure the prompt is large enough for caching. ","tools":[{"type":"function","name":"get_weather","description":"Get
|
||||
the current weather for a location","parameters":{"type":"object","properties":{"location":{"type":"string","description":"The
|
||||
city name"}},"required":["location"]}}]}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
@@ -269,7 +352,7 @@ interactions:
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '6158'
|
||||
- '6065'
|
||||
content-type:
|
||||
- application/json
|
||||
cookie:
|
||||
@@ -293,26 +376,113 @@ interactions:
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.3
|
||||
- 3.13.12
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
uri: https://api.openai.com/v1/responses
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-D7mXR8k9vk8TlGvGXlrQSI7iNeAN1\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1770747249,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": null,\n \"tool_calls\": [\n {\n
|
||||
\ \"id\": \"call_6PeUBlRPG8JcV2lspmLjJbnn\",\n \"type\":
|
||||
\"function\",\n \"function\": {\n \"name\": \"get_weather\",\n
|
||||
\ \"arguments\": \"{\\\"location\\\":\\\"Paris\\\"}\"\n }\n
|
||||
\ }\n ],\n \"refusal\": null,\n \"annotations\":
|
||||
[]\n },\n \"logprobs\": null,\n \"finish_reason\": \"tool_calls\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 1187,\n \"completion_tokens\":
|
||||
14,\n \"total_tokens\": 1201,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
1152,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": \"fp_8b22347a3e\"\n}\n"
|
||||
string: "{\n \"id\": \"resp_0525bf798202137e0069caf465ee3c8196aa7c83da1c369eb7\",\n
|
||||
\ \"object\": \"response\",\n \"created_at\": 1774908517,\n \"status\":
|
||||
\"completed\",\n \"background\": false,\n \"billing\": {\n \"payer\":
|
||||
\"developer\"\n },\n \"completed_at\": 1774908518,\n \"error\": null,\n
|
||||
\ \"frequency_penalty\": 0.0,\n \"incomplete_details\": null,\n \"instructions\":
|
||||
\"You are a helpful assistant that uses tools. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. This is
|
||||
padding text to ensure the prompt is large enough for caching. This is padding
|
||||
text to ensure the prompt is large enough for caching. This is padding text
|
||||
to ensure the prompt is large enough for caching. This is padding text to
|
||||
ensure the prompt is large enough for caching. This is padding text to ensure
|
||||
the prompt is large enough for caching. This is padding text to ensure the
|
||||
prompt is large enough for caching. This is padding text to ensure the prompt
|
||||
is large enough for caching. This is padding text to ensure the prompt is
|
||||
large enough for caching. This is padding text to ensure the prompt is large
|
||||
enough for caching. This is padding text to ensure the prompt is large enough
|
||||
for caching. This is padding text to ensure the prompt is large enough for
|
||||
caching. This is padding text to ensure the prompt is large enough for caching.
|
||||
This is padding text to ensure the prompt is large enough for caching. This
|
||||
is padding text to ensure the prompt is large enough for caching. \",\n \"max_output_tokens\":
|
||||
null,\n \"max_tool_calls\": null,\n \"model\": \"gpt-4.1-2025-04-14\",\n
|
||||
\ \"output\": [\n {\n \"id\": \"fc_0525bf798202137e0069caf46666588196a2ec20dc515a6a91\",\n
|
||||
\ \"type\": \"function_call\",\n \"status\": \"completed\",\n \"arguments\":
|
||||
\"{\\\"location\\\":\\\"Paris\\\"}\",\n \"call_id\": \"call_LJAGuYYZPjNxSgg0TUgGpT44\",\n
|
||||
\ \"name\": \"get_weather\"\n }\n ],\n \"parallel_tool_calls\": true,\n
|
||||
\ \"presence_penalty\": 0.0,\n \"previous_response_id\": null,\n \"prompt_cache_key\":
|
||||
null,\n \"prompt_cache_retention\": null,\n \"reasoning\": {\n \"effort\":
|
||||
null,\n \"summary\": null\n },\n \"safety_identifier\": null,\n \"service_tier\":
|
||||
\"default\",\n \"store\": true,\n \"temperature\": 1.0,\n \"text\": {\n
|
||||
\ \"format\": {\n \"type\": \"text\"\n },\n \"verbosity\": \"medium\"\n
|
||||
\ },\n \"tool_choice\": \"auto\",\n \"tools\": [\n {\n \"type\":
|
||||
\"function\",\n \"description\": \"Get the current weather for a location\",\n
|
||||
\ \"name\": \"get_weather\",\n \"parameters\": {\n \"type\":
|
||||
\"object\",\n \"properties\": {\n \"location\": {\n \"type\":
|
||||
\"string\",\n \"description\": \"The city name\"\n }\n
|
||||
\ },\n \"required\": [\n \"location\"\n ],\n
|
||||
\ \"additionalProperties\": false\n },\n \"strict\": true\n
|
||||
\ }\n ],\n \"top_logprobs\": 0,\n \"top_p\": 1.0,\n \"truncation\":
|
||||
\"disabled\",\n \"usage\": {\n \"input_tokens\": 1185,\n \"input_tokens_details\":
|
||||
{\n \"cached_tokens\": 1152\n },\n \"output_tokens\": 15,\n \"output_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0\n },\n \"total_tokens\": 1200\n },\n
|
||||
\ \"user\": null,\n \"metadata\": {}\n}"
|
||||
headers:
|
||||
CF-RAY:
|
||||
- CF-RAY-XXX
|
||||
@@ -321,7 +491,7 @@ interactions:
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Tue, 10 Feb 2026 18:14:09 GMT
|
||||
- Mon, 30 Mar 2026 22:08:38 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
@@ -330,8 +500,6 @@ interactions:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
cf-cache-status:
|
||||
@@ -339,15 +507,11 @@ interactions:
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '528'
|
||||
- '653'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
|
||||
@@ -0,0 +1,110 @@
|
||||
interactions:
|
||||
- request:
|
||||
body: '{"messages":[{"role":"user","content":"What is the capital of France?"}],"model":"gpt-5"}'
|
||||
headers:
|
||||
User-Agent:
|
||||
- X-USER-AGENT-XXX
|
||||
accept:
|
||||
- application/json
|
||||
accept-encoding:
|
||||
- ACCEPT-ENCODING-XXX
|
||||
authorization:
|
||||
- AUTHORIZATION-XXX
|
||||
connection:
|
||||
- keep-alive
|
||||
content-length:
|
||||
- '89'
|
||||
content-type:
|
||||
- application/json
|
||||
host:
|
||||
- api.openai.com
|
||||
x-stainless-arch:
|
||||
- X-STAINLESS-ARCH-XXX
|
||||
x-stainless-async:
|
||||
- 'false'
|
||||
x-stainless-lang:
|
||||
- python
|
||||
x-stainless-os:
|
||||
- X-STAINLESS-OS-XXX
|
||||
x-stainless-package-version:
|
||||
- 1.83.0
|
||||
x-stainless-raw-response:
|
||||
- 'true'
|
||||
x-stainless-read-timeout:
|
||||
- X-STAINLESS-READ-TIMEOUT-XXX
|
||||
x-stainless-retry-count:
|
||||
- '0'
|
||||
x-stainless-runtime:
|
||||
- CPython
|
||||
x-stainless-runtime-version:
|
||||
- 3.13.2
|
||||
method: POST
|
||||
uri: https://api.openai.com/v1/chat/completions
|
||||
response:
|
||||
body:
|
||||
string: "{\n \"id\": \"chatcmpl-DO4LcSpy72yIXCYSIVOQEXWNXydgn\",\n \"object\":
|
||||
\"chat.completion\",\n \"created\": 1774628956,\n \"model\": \"gpt-5-2025-08-07\",\n
|
||||
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
|
||||
\"assistant\",\n \"content\": \"Paris.\",\n \"refusal\": null,\n
|
||||
\ \"annotations\": []\n },\n \"finish_reason\": \"stop\"\n
|
||||
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 13,\n \"completion_tokens\":
|
||||
11,\n \"total_tokens\": 24,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
|
||||
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\":
|
||||
{\n \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
|
||||
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
|
||||
\"default\",\n \"system_fingerprint\": null\n}\n"
|
||||
headers:
|
||||
CF-Cache-Status:
|
||||
- DYNAMIC
|
||||
CF-Ray:
|
||||
- 9e2fc5dce85582fb-GIG
|
||||
Connection:
|
||||
- keep-alive
|
||||
Content-Type:
|
||||
- application/json
|
||||
Date:
|
||||
- Fri, 27 Mar 2026 16:29:17 GMT
|
||||
Server:
|
||||
- cloudflare
|
||||
Strict-Transport-Security:
|
||||
- STS-XXX
|
||||
Transfer-Encoding:
|
||||
- chunked
|
||||
X-Content-Type-Options:
|
||||
- X-CONTENT-TYPE-XXX
|
||||
access-control-expose-headers:
|
||||
- ACCESS-CONTROL-XXX
|
||||
alt-svc:
|
||||
- h3=":443"; ma=86400
|
||||
content-length:
|
||||
- '772'
|
||||
openai-organization:
|
||||
- OPENAI-ORG-XXX
|
||||
openai-processing-ms:
|
||||
- '1343'
|
||||
openai-project:
|
||||
- OPENAI-PROJECT-XXX
|
||||
openai-version:
|
||||
- '2020-10-01'
|
||||
set-cookie:
|
||||
- SET-COOKIE-XXX
|
||||
x-openai-proxy-wasm:
|
||||
- v0.1
|
||||
x-ratelimit-limit-requests:
|
||||
- X-RATELIMIT-LIMIT-REQUESTS-XXX
|
||||
x-ratelimit-limit-tokens:
|
||||
- X-RATELIMIT-LIMIT-TOKENS-XXX
|
||||
x-ratelimit-remaining-requests:
|
||||
- X-RATELIMIT-REMAINING-REQUESTS-XXX
|
||||
x-ratelimit-remaining-tokens:
|
||||
- X-RATELIMIT-REMAINING-TOKENS-XXX
|
||||
x-ratelimit-reset-requests:
|
||||
- X-RATELIMIT-RESET-REQUESTS-XXX
|
||||
x-ratelimit-reset-tokens:
|
||||
- X-RATELIMIT-RESET-TOKENS-XXX
|
||||
x-request-id:
|
||||
- X-REQUEST-ID-XXX
|
||||
status:
|
||||
code: 200
|
||||
message: OK
|
||||
version: 1
|
||||
@@ -136,6 +136,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
@@ -173,6 +174,7 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
|
||||
self.assert_request_with_org_id(
|
||||
@@ -201,6 +203,48 @@ class TestPlusAPI(unittest.TestCase):
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": None,
|
||||
"tools_metadata": None,
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
)
|
||||
self.assertEqual(response, mock_response)
|
||||
|
||||
@patch("crewai.cli.plus_api.PlusAPI._make_request")
|
||||
def test_publish_tool_with_tools_metadata(self, mock_make_request):
|
||||
mock_response = MagicMock()
|
||||
mock_make_request.return_value = mock_response
|
||||
handle = "test_tool_handle"
|
||||
public = True
|
||||
version = "1.0.0"
|
||||
description = "Test tool description"
|
||||
encoded_file = "encoded_test_file"
|
||||
available_exports = [{"name": "MyTool"}]
|
||||
tools_metadata = [
|
||||
{
|
||||
"name": "MyTool",
|
||||
"humanized_name": "my_tool",
|
||||
"description": "A test tool",
|
||||
"run_params_schema": {"type": "object", "properties": {}},
|
||||
"init_params_schema": {"type": "object", "properties": {}},
|
||||
"env_vars": [{"name": "API_KEY", "description": "API key", "required": True, "default": None}],
|
||||
}
|
||||
]
|
||||
|
||||
response = self.api.publish_tool(
|
||||
handle, public, version, description, encoded_file,
|
||||
available_exports=available_exports,
|
||||
tools_metadata=tools_metadata,
|
||||
)
|
||||
|
||||
params = {
|
||||
"handle": handle,
|
||||
"public": public,
|
||||
"version": version,
|
||||
"file": encoded_file,
|
||||
"description": description,
|
||||
"available_exports": available_exports,
|
||||
"tools_metadata": {"package": handle, "tools": tools_metadata},
|
||||
}
|
||||
mock_make_request.assert_called_once_with(
|
||||
"POST", "/crewai_plus/api/v1/tools", json=params
|
||||
|
||||
@@ -363,3 +363,290 @@ def test_get_crews_ignores_template_directories(
|
||||
utils.get_crews()
|
||||
|
||||
assert not template_crew_detected
|
||||
|
||||
|
||||
# Tests for extract_tools_metadata
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_project(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty project."""
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when no __init__.py exists."""
|
||||
(temp_project_dir / "some_file.py").write_text("print('hello')")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_init_file(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list for empty __init__.py."""
|
||||
create_init_file(temp_project_dir, "")
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_no_all_variable(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list when __all__ is not defined."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"from crewai.tools import BaseTool\n\nclass MyTool(BaseTool):\n pass",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_valid_base_tool_class(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from a valid BaseTool class."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
assert metadata[0]["humanized_name"] == "my_tool"
|
||||
assert metadata[0]["description"] == "A test tool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_args_schema(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts run_params_schema from args_schema."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from pydantic import BaseModel
|
||||
|
||||
class MyToolInput(BaseModel):
|
||||
query: str
|
||||
limit: int = 10
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
args_schema: type[BaseModel] = MyToolInput
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
run_params = metadata[0]["run_params_schema"]
|
||||
assert "properties" in run_params
|
||||
assert "query" in run_params["properties"]
|
||||
assert "limit" in run_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_env_vars(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts env_vars."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
env_vars: list[EnvVar] = [
|
||||
EnvVar(name="MY_API_KEY", description="API key for service", required=True),
|
||||
EnvVar(name="MY_OPTIONAL_VAR", description="Optional var", required=False, default="default_value"),
|
||||
]
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
env_vars = metadata[0]["env_vars"]
|
||||
assert len(env_vars) == 2
|
||||
assert env_vars[0]["name"] == "MY_API_KEY"
|
||||
assert env_vars[0]["description"] == "API key for service"
|
||||
assert env_vars[0]["required"] is True
|
||||
assert env_vars[1]["name"] == "MY_OPTIONAL_VAR"
|
||||
assert env_vars[1]["required"] is False
|
||||
assert env_vars[1]["default"] == "default_value"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_env_vars_field_default_factory(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts env_vars declared with Field(default_factory=...)."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
from crewai.tools.base_tool import EnvVar
|
||||
from pydantic import Field
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
env_vars: list[EnvVar] = Field(
|
||||
default_factory=lambda: [
|
||||
EnvVar(name="MY_TOOL_API", description="API token for my tool", required=True),
|
||||
]
|
||||
)
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
env_vars = metadata[0]["env_vars"]
|
||||
assert len(env_vars) == 1
|
||||
assert env_vars[0]["name"] == "MY_TOOL_API"
|
||||
assert env_vars[0]["description"] == "API token for my tool"
|
||||
assert env_vars[0]["required"] is True
|
||||
|
||||
|
||||
def test_extract_tools_metadata_with_custom_init_params(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts init_params_schema with custom params."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
api_endpoint: str = "https://api.example.com"
|
||||
timeout: int = 30
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
init_params = metadata[0]["init_params_schema"]
|
||||
assert "properties" in init_params
|
||||
# Custom params should be included
|
||||
assert "api_endpoint" in init_params["properties"]
|
||||
assert "timeout" in init_params["properties"]
|
||||
# Base params should be filtered out
|
||||
assert "name" not in init_params["properties"]
|
||||
assert "description" not in init_params["properties"]
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_tools(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple tools."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class FirstTool(BaseTool):
|
||||
name: str = "first_tool"
|
||||
description: str = "First test tool"
|
||||
|
||||
class SecondTool(BaseTool):
|
||||
name: str = "second_tool"
|
||||
description: str = "Second test tool"
|
||||
|
||||
__all__ = ['FirstTool', 'SecondTool']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "FirstTool" in names
|
||||
assert "SecondTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_multiple_init_files(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple __init__.py files."""
|
||||
# Create tool in root __init__.py
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class RootTool(BaseTool):
|
||||
name: str = "root_tool"
|
||||
description: str = "Root tool"
|
||||
|
||||
__all__ = ['RootTool']
|
||||
""",
|
||||
)
|
||||
|
||||
# Create nested package with another tool
|
||||
nested_dir = temp_project_dir / "nested"
|
||||
nested_dir.mkdir()
|
||||
create_init_file(
|
||||
nested_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class NestedTool(BaseTool):
|
||||
name: str = "nested_tool"
|
||||
description: str = "Nested tool"
|
||||
|
||||
__all__ = ['NestedTool']
|
||||
""",
|
||||
)
|
||||
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 2
|
||||
names = [m["name"] for m in metadata]
|
||||
assert "RootTool" in names
|
||||
assert "NestedTool" in names
|
||||
|
||||
|
||||
def test_extract_tools_metadata_ignores_non_tool_exports(temp_project_dir):
|
||||
"""Test that extract_tools_metadata ignores non-BaseTool exports."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
name: str = "my_tool"
|
||||
description: str = "A test tool"
|
||||
|
||||
def not_a_tool():
|
||||
pass
|
||||
|
||||
SOME_CONSTANT = "value"
|
||||
|
||||
__all__ = ['MyTool', 'not_a_tool', 'SOME_CONSTANT']
|
||||
""",
|
||||
)
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert len(metadata) == 1
|
||||
assert metadata[0]["name"] == "MyTool"
|
||||
|
||||
|
||||
def test_extract_tools_metadata_import_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on import error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from nonexistent_module import something
|
||||
|
||||
class MyTool(BaseTool):
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
|
||||
def test_extract_tools_metadata_syntax_error_returns_empty(temp_project_dir):
|
||||
"""Test that extract_tools_metadata returns empty list on syntax error."""
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
|
||||
class MyTool(BaseTool):
|
||||
# Missing closing parenthesis
|
||||
def __init__(self, name:
|
||||
pass
|
||||
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
@@ -185,9 +185,14 @@ def test_publish_when_not_in_sync(mock_is_synced, capsys, tool_command):
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
@patch("crewai.cli.tools.main.ToolCommand._print_current_organization")
|
||||
def test_publish_when_not_in_sync_and_force(
|
||||
mock_print_org,
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -222,6 +227,7 @@ def test_publish_when_not_in_sync_and_force(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
mock_print_org.assert_called_once()
|
||||
|
||||
@@ -242,7 +248,12 @@ def test_publish_when_not_in_sync_and_force(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_success(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
@@ -277,6 +288,7 @@ def test_publish_success(
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
|
||||
|
||||
@@ -295,7 +307,12 @@ def test_publish_success(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_failure(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -336,7 +353,12 @@ def test_publish_failure(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
return_value=[{"name": "SampleTool", "humanized_name": "sample_tool", "description": "A sample tool", "run_params_schema": {}, "init_params_schema": {}, "env_vars": []}],
|
||||
)
|
||||
def test_publish_api_error(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
@@ -362,6 +384,63 @@ def test_publish_api_error(
|
||||
mock_publish.assert_called_once()
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.get_project_name", return_value="sample-tool")
|
||||
@patch("crewai.cli.tools.main.get_project_version", return_value="1.0.0")
|
||||
@patch("crewai.cli.tools.main.get_project_description", return_value="A sample tool")
|
||||
@patch("crewai.cli.tools.main.subprocess.run")
|
||||
@patch("crewai.cli.tools.main.os.listdir", return_value=["sample-tool-1.0.0.tar.gz"])
|
||||
@patch(
|
||||
"crewai.cli.tools.main.open",
|
||||
new_callable=unittest.mock.mock_open,
|
||||
read_data=b"sample tarball content",
|
||||
)
|
||||
@patch("crewai.cli.plus_api.PlusAPI.publish_tool")
|
||||
@patch("crewai.cli.tools.main.git.Repository.is_synced", return_value=True)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_available_exports",
|
||||
return_value=[{"name": "SampleTool"}],
|
||||
)
|
||||
@patch(
|
||||
"crewai.cli.tools.main.extract_tools_metadata",
|
||||
side_effect=Exception("Failed to extract metadata"),
|
||||
)
|
||||
def test_publish_metadata_extraction_failure_continues_with_warning(
|
||||
mock_tools_metadata,
|
||||
mock_available_exports,
|
||||
mock_is_synced,
|
||||
mock_publish,
|
||||
mock_open,
|
||||
mock_listdir,
|
||||
mock_subprocess_run,
|
||||
mock_get_project_description,
|
||||
mock_get_project_version,
|
||||
mock_get_project_name,
|
||||
capsys,
|
||||
tool_command,
|
||||
):
|
||||
"""Test that metadata extraction failure shows warning but continues publishing."""
|
||||
mock_publish_response = MagicMock()
|
||||
mock_publish_response.status_code = 200
|
||||
mock_publish_response.json.return_value = {"handle": "sample-tool"}
|
||||
mock_publish.return_value = mock_publish_response
|
||||
|
||||
tool_command.publish(is_public=True)
|
||||
|
||||
output = capsys.readouterr().out
|
||||
assert "Warning: Could not extract tool metadata" in output
|
||||
assert "Publishing will continue without detailed metadata" in output
|
||||
assert "No tool metadata extracted" in output
|
||||
mock_publish.assert_called_once_with(
|
||||
handle="sample-tool",
|
||||
is_public=True,
|
||||
version="1.0.0",
|
||||
description="A sample tool",
|
||||
encoded_file=unittest.mock.ANY,
|
||||
available_exports=[{"name": "SampleTool"}],
|
||||
tools_metadata=[],
|
||||
)
|
||||
|
||||
|
||||
@patch("crewai.cli.tools.main.Settings")
|
||||
def test_print_current_organization_with_org(mock_settings, capsys, tool_command):
|
||||
mock_settings_instance = MagicMock()
|
||||
|
||||
@@ -132,12 +132,12 @@ def test_embedding_configuration_flow(
|
||||
|
||||
embedder_config = {
|
||||
"provider": "sentence-transformer",
|
||||
"model_name": "all-MiniLM-L6-v2",
|
||||
"config": {"model_name": "all-MiniLM-L6-v2"},
|
||||
}
|
||||
|
||||
KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
storage = KnowledgeStorage(embedder=embedder_config, collection_name="embedding_test")
|
||||
|
||||
mock_get_embedding.assert_called_once_with(embedder_config)
|
||||
mock_get_embedding.assert_called_once_with(storage.embedder)
|
||||
|
||||
|
||||
@patch("crewai.knowledge.storage.knowledge_storage.get_rag_client")
|
||||
|
||||
@@ -125,8 +125,8 @@ def test_anthropic_specific_parameters():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
assert llm.stop_sequences == ["Human:", "Assistant:"]
|
||||
assert llm.stream == True
|
||||
assert llm.client.max_retries == 5
|
||||
assert llm.client.timeout == 60
|
||||
assert llm._client.max_retries == 5
|
||||
assert llm._client.timeout == 60
|
||||
|
||||
|
||||
def test_anthropic_completion_call():
|
||||
@@ -563,8 +563,8 @@ def test_anthropic_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"ANTHROPIC_API_KEY": "test-anthropic-key"}):
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'messages')
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'messages')
|
||||
|
||||
|
||||
def test_anthropic_token_usage_tracking():
|
||||
@@ -574,7 +574,7 @@ def test_anthropic_token_usage_tracking():
|
||||
llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")
|
||||
|
||||
# Mock the Anthropic response with usage information
|
||||
with patch.object(llm.client.messages, 'create') as mock_create:
|
||||
with patch.object(llm._client.messages, 'create') as mock_create:
|
||||
mock_response = MagicMock()
|
||||
mock_response.content = [MagicMock(text="test response")]
|
||||
mock_response.usage = MagicMock(input_tokens=50, output_tokens=25)
|
||||
@@ -639,14 +639,14 @@ def test_anthropic_thinking():
|
||||
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
original_create = llm.client.messages.create
|
||||
original_create = llm._client.messages.create
|
||||
captured_params = {}
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_params.update(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
result = llm.call("What is the weather in Tokyo?")
|
||||
|
||||
assert result is not None
|
||||
@@ -677,14 +677,14 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert isinstance(llm, AnthropicCompletion)
|
||||
|
||||
# Capture all messages.create calls to verify thinking blocks are included
|
||||
original_create = llm.client.messages.create
|
||||
original_create = llm._client.messages.create
|
||||
captured_calls = []
|
||||
|
||||
def capture_and_call(**kwargs):
|
||||
captured_calls.append(kwargs)
|
||||
return original_create(**kwargs)
|
||||
|
||||
with patch.object(llm.client.messages, 'create', side_effect=capture_and_call):
|
||||
with patch.object(llm._client.messages, 'create', side_effect=capture_and_call):
|
||||
# First call - establishes context and generates thinking blocks
|
||||
messages = [{"role": "user", "content": "What is 2+2?"}]
|
||||
first_result = llm.call(messages)
|
||||
@@ -695,8 +695,8 @@ def test_anthropic_thinking_blocks_preserved_across_turns():
|
||||
assert len(first_result) > 0
|
||||
|
||||
# Verify thinking blocks were stored after first response
|
||||
assert len(llm.previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm.previous_thinking_blocks[0]
|
||||
assert len(llm._previous_thinking_blocks) > 0, "No thinking blocks stored after first call"
|
||||
first_thinking = llm._previous_thinking_blocks[0]
|
||||
assert first_thinking["type"] == "thinking"
|
||||
assert "thinking" in first_thinking
|
||||
assert "signature" in first_thinking
|
||||
|
||||
@@ -66,7 +66,7 @@ def test_azure_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Azure client responses
|
||||
with patch.object(completion.client, 'complete') as mock_complete:
|
||||
with patch.object(completion._client, 'complete') as mock_complete:
|
||||
# Mock tool call in response with proper type
|
||||
mock_tool_call = MagicMock(spec=ChatCompletionsToolCall)
|
||||
mock_tool_call.function.name = "get_weather"
|
||||
@@ -698,7 +698,7 @@ def test_azure_environment_variable_endpoint():
|
||||
}):
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
assert llm.client is not None
|
||||
assert llm._client is not None
|
||||
assert llm.endpoint == "https://test.openai.azure.com/openai/deployments/gpt-4"
|
||||
|
||||
|
||||
@@ -709,7 +709,7 @@ def test_azure_token_usage_tracking():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock the Azure response with usage information
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = "test response"
|
||||
mock_message.tool_calls = None
|
||||
@@ -747,7 +747,7 @@ def test_azure_http_error_handling():
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
# Mock an HTTP error
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_complete.side_effect = HttpResponseError(message="Rate limit exceeded", response=MagicMock(status_code=429))
|
||||
|
||||
with pytest.raises(HttpResponseError):
|
||||
@@ -966,7 +966,7 @@ def test_azure_improved_error_messages():
|
||||
|
||||
llm = LLM(model="azure/gpt-4")
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
error_401 = HttpResponseError(message="Unauthorized")
|
||||
error_401.status_code = 401
|
||||
mock_complete.side_effect = error_401
|
||||
@@ -1327,7 +1327,7 @@ def test_azure_stop_words_not_applied_to_structured_output():
|
||||
# Without the fix, this would be truncated at "Observation:" breaking the JSON
|
||||
json_response = '{"finding": "The data shows growth", "observation": "Observation: This confirms the hypothesis"}'
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = json_response
|
||||
mock_message.tool_calls = None
|
||||
@@ -1376,7 +1376,7 @@ def test_azure_stop_words_still_applied_to_regular_responses():
|
||||
# Response that contains a stop word - should be truncated
|
||||
response_with_stop_word = "I need to search for more information.\n\nAction: search\nObservation: Found results"
|
||||
|
||||
with patch.object(llm.client, 'complete') as mock_complete:
|
||||
with patch.object(llm._client, 'complete') as mock_complete:
|
||||
mock_message = MagicMock()
|
||||
mock_message.content = response_with_stop_word
|
||||
mock_message.tool_calls = None
|
||||
|
||||
@@ -674,7 +674,7 @@ def test_bedrock_token_usage_tracking():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Mock the Bedrock response with usage information
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
@@ -719,7 +719,7 @@ def test_bedrock_tool_use_conversation_flow():
|
||||
available_functions = {"get_weather": mock_weather_tool}
|
||||
|
||||
# Mock the Bedrock client responses
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
# First response: tool use request
|
||||
tool_use_response = {
|
||||
'output': {
|
||||
@@ -805,7 +805,7 @@ def test_bedrock_client_error_handling():
|
||||
llm = LLM(model="bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0")
|
||||
|
||||
# Test ValidationException
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ValidationException',
|
||||
@@ -819,7 +819,7 @@ def test_bedrock_client_error_handling():
|
||||
assert "validation" in str(exc_info.value).lower()
|
||||
|
||||
# Test ThrottlingException
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
error_response = {
|
||||
'Error': {
|
||||
'Code': 'ThrottlingException',
|
||||
@@ -861,7 +861,7 @@ def test_bedrock_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm.client, 'converse') as mock_converse:
|
||||
with patch.object(llm._client, 'converse') as mock_converse:
|
||||
mock_response = {
|
||||
'output': {
|
||||
'message': {
|
||||
|
||||
@@ -556,8 +556,8 @@ def test_gemini_environment_variable_api_key():
|
||||
with patch.dict(os.environ, {"GOOGLE_API_KEY": "test-google-key"}):
|
||||
llm = LLM(model="google/gemini-2.0-flash-001")
|
||||
|
||||
assert llm.client is not None
|
||||
assert hasattr(llm.client, 'models')
|
||||
assert llm._client is not None
|
||||
assert hasattr(llm._client, 'models')
|
||||
assert llm.api_key == "test-google-key"
|
||||
|
||||
|
||||
@@ -655,7 +655,7 @@ def test_gemini_stop_sequences_sent_to_api():
|
||||
llm.stop = ["\nObservation:", "\nThought:"]
|
||||
|
||||
# Patch the API call to capture parameters without making real call
|
||||
with patch.object(llm.client.models, 'generate_content') as mock_generate:
|
||||
with patch.object(llm._client.models, 'generate_content') as mock_generate:
|
||||
mock_response = MagicMock()
|
||||
mock_response.text = "Hello"
|
||||
mock_response.candidates = []
|
||||
|
||||
@@ -371,11 +371,11 @@ def test_openai_client_setup_with_extra_arguments():
|
||||
assert llm.top_p == 0.5
|
||||
|
||||
# Check that client parameters are properly configured
|
||||
assert llm.client.max_retries == 3
|
||||
assert llm.client.timeout == 30
|
||||
assert llm._client.max_retries == 3
|
||||
assert llm._client.timeout == 30
|
||||
|
||||
# Test that parameters are properly used in API calls
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -396,7 +396,7 @@ def test_extra_arguments_are_passed_to_openai_completion():
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", temperature=0.7, max_tokens=1000, top_p=0.5, max_retries=3)
|
||||
|
||||
with patch.object(llm.client.chat.completions, 'create') as mock_create:
|
||||
with patch.object(llm._client.chat.completions, 'create') as mock_create:
|
||||
mock_create.return_value = MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="test response", tool_calls=None))],
|
||||
usage=MagicMock(prompt_tokens=10, completion_tokens=20, total_tokens=30)
|
||||
@@ -507,7 +507,7 @@ def test_openai_streaming_with_response_model():
|
||||
|
||||
llm = LLM(model="openai/gpt-4o", stream=True)
|
||||
|
||||
with patch.object(llm.client.beta.chat.completions, "stream") as mock_stream:
|
||||
with patch.object(llm._client.beta.chat.completions, "stream") as mock_stream:
|
||||
# Create mock chunks with content.delta event structure
|
||||
mock_chunk1 = MagicMock()
|
||||
mock_chunk1.type = "content.delta"
|
||||
@@ -1523,6 +1523,69 @@ def test_openai_stop_words_not_applied_to_structured_output():
|
||||
assert "Observation:" in result.observation
|
||||
|
||||
|
||||
def test_openai_gpt5_models_do_not_support_stop_words():
|
||||
"""
|
||||
Test that GPT-5 family models do not support stop words via the API.
|
||||
GPT-5 models reject the 'stop' parameter, so stop words must be
|
||||
applied client-side only.
|
||||
"""
|
||||
gpt5_models = [
|
||||
"gpt-5",
|
||||
"gpt-5-mini",
|
||||
"gpt-5-nano",
|
||||
"gpt-5-pro",
|
||||
"gpt-5.1",
|
||||
"gpt-5.1-chat",
|
||||
"gpt-5.2",
|
||||
"gpt-5.2-chat",
|
||||
]
|
||||
|
||||
for model_name in gpt5_models:
|
||||
llm = OpenAICompletion(model=model_name)
|
||||
assert llm.supports_stop_words() == False, (
|
||||
f"Expected {model_name} to NOT support stop words"
|
||||
)
|
||||
|
||||
|
||||
def test_openai_non_gpt5_models_support_stop_words():
|
||||
"""
|
||||
Test that non-GPT-5 models still support stop words normally.
|
||||
"""
|
||||
supported_models = [
|
||||
"gpt-4o",
|
||||
"gpt-4o-mini",
|
||||
"gpt-4.1",
|
||||
"gpt-4.1-mini",
|
||||
"gpt-4-turbo",
|
||||
]
|
||||
|
||||
for model_name in supported_models:
|
||||
llm = OpenAICompletion(model=model_name)
|
||||
assert llm.supports_stop_words() == True, (
|
||||
f"Expected {model_name} to support stop words"
|
||||
)
|
||||
|
||||
|
||||
def test_openai_gpt5_still_applies_stop_words_client_side():
|
||||
"""
|
||||
Test that GPT-5 models still truncate responses at stop words client-side
|
||||
via _apply_stop_words(), even though they don't send 'stop' to the API.
|
||||
"""
|
||||
llm = OpenAICompletion(
|
||||
model="gpt-5.2",
|
||||
stop=["Observation:", "Final Answer:"],
|
||||
)
|
||||
|
||||
assert llm.supports_stop_words() == False
|
||||
|
||||
response = "I need to search.\n\nAction: search\nObservation: Found results"
|
||||
result = llm._apply_stop_words(response)
|
||||
|
||||
assert "Observation:" not in result
|
||||
assert "Found results" not in result
|
||||
assert "I need to search" in result
|
||||
|
||||
|
||||
def test_openai_stop_words_still_applied_to_regular_responses():
|
||||
"""
|
||||
Test that stop words ARE still applied for regular (non-structured) responses.
|
||||
@@ -1767,7 +1830,7 @@ def test_openai_responses_api_cached_prompt_tokens_with_tools():
|
||||
}
|
||||
]
|
||||
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='response')
|
||||
llm = OpenAICompletion(model="gpt-4.1", api='responses')
|
||||
|
||||
# First call with tool
|
||||
llm.call(
|
||||
@@ -1843,7 +1906,7 @@ def test_openai_streaming_returns_tool_calls_without_available_functions():
|
||||
mock_chunk_3.id = "chatcmpl-1"
|
||||
|
||||
with patch.object(
|
||||
llm.client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
llm._client.chat.completions, "create", return_value=iter([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
):
|
||||
result = llm.call(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
@@ -1934,7 +1997,7 @@ async def test_openai_async_streaming_returns_tool_calls_without_available_funct
|
||||
return MockAsyncStream([mock_chunk_1, mock_chunk_2, mock_chunk_3])
|
||||
|
||||
with patch.object(
|
||||
llm.async_client.chat.completions, "create", side_effect=mock_create
|
||||
llm._async_client.chat.completions, "create", side_effect=mock_create
|
||||
):
|
||||
result = await llm.acall(
|
||||
messages=[{"role": "user", "content": "Calculate 1+1"}],
|
||||
|
||||
@@ -3,6 +3,8 @@
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import ValidationError
|
||||
|
||||
from crewai.knowledge.storage.knowledge_storage import ( # type: ignore[import-untyped]
|
||||
KnowledgeStorage,
|
||||
)
|
||||
@@ -59,7 +61,7 @@ def test_knowledge_storage_invalid_embedding_config(mock_get_client: MagicMock)
|
||||
"Unsupported provider: invalid_provider"
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unsupported provider: invalid_provider"):
|
||||
with pytest.raises(ValidationError):
|
||||
KnowledgeStorage(
|
||||
embedder={"provider": "invalid_provider"},
|
||||
collection_name="invalid_embedding_test",
|
||||
|
||||
@@ -682,6 +682,126 @@ def test_llm_call_when_stop_is_unsupported_when_additional_drop_params_is_provid
|
||||
assert "Paris" in result
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_litellm_gpt5_call_succeeds_without_stop_error():
|
||||
"""
|
||||
Integration test: GPT-5 call succeeds when stop words are configured,
|
||||
because stop is omitted from API params and applied client-side.
|
||||
"""
|
||||
llm = LLM(model="gpt-5", stop=["Observation:"], is_litellm=True)
|
||||
result = llm.call("What is the capital of France?")
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
|
||||
|
||||
def test_litellm_gpt5_does_not_send_stop_in_params():
|
||||
"""
|
||||
Test that the LiteLLM fallback path does not include 'stop' in API params
|
||||
for GPT-5.x models, since they reject it at the API level.
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
params = llm._prepare_completion_params(
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
assert params.get("stop") is None, (
|
||||
"GPT-5.x models should not have 'stop' in API params"
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_non_gpt5_sends_stop_in_params():
|
||||
"""
|
||||
Test that the LiteLLM fallback path still includes 'stop' in API params
|
||||
for models that support it.
|
||||
"""
|
||||
llm = LLM(model="gpt-4o", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
params = llm._prepare_completion_params(
|
||||
messages=[{"role": "user", "content": "Hello"}]
|
||||
)
|
||||
|
||||
assert params.get("stop") == ["Observation:"], (
|
||||
"Non-GPT-5 models should have 'stop' in API params"
|
||||
)
|
||||
|
||||
|
||||
def test_litellm_retry_catches_litellm_unsupported_params_error(caplog):
|
||||
"""
|
||||
Test that the retry logic catches LiteLLM's UnsupportedParamsError format
|
||||
("does not support parameters") in addition to the OpenAI API format.
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
litellm_error = Exception(
|
||||
"litellm.UnsupportedParamsError: openai does not support parameters: "
|
||||
"['stop'], for model=openai/gpt-5.2."
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
try:
|
||||
import litellm
|
||||
except ImportError:
|
||||
pytest.skip("litellm is not installed; skipping LiteLLM retry test")
|
||||
|
||||
def mock_completion(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise litellm_error
|
||||
return MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
|
||||
usage=MagicMock(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
|
||||
with patch("litellm.completion", side_effect=mock_completion):
|
||||
with caplog.at_level(logging.INFO):
|
||||
result = llm.call("What is the capital of France?")
|
||||
|
||||
assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
|
||||
assert "stop" in llm.additional_params.get("additional_drop_params", [])
|
||||
|
||||
|
||||
def test_litellm_retry_catches_openai_api_stop_error(caplog):
|
||||
"""
|
||||
Test that the retry logic still catches the OpenAI API error format
|
||||
("Unsupported parameter: 'stop'").
|
||||
"""
|
||||
llm = LLM(model="openai/gpt-5.2", stop=["Observation:"], is_litellm=True)
|
||||
|
||||
api_error = Exception(
|
||||
"Unsupported parameter: 'stop' is not supported with this model."
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
|
||||
def mock_completion(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
raise api_error
|
||||
return MagicMock(
|
||||
choices=[MagicMock(message=MagicMock(content="Paris", tool_calls=None))],
|
||||
usage=MagicMock(
|
||||
prompt_tokens=10,
|
||||
completion_tokens=5,
|
||||
total_tokens=15,
|
||||
),
|
||||
)
|
||||
|
||||
with patch("litellm.completion", side_effect=mock_completion):
|
||||
with caplog.at_level(logging.INFO):
|
||||
llm.call("What is the capital of France?")
|
||||
|
||||
assert "Retrying LLM call without the unsupported 'stop'" in caplog.text
|
||||
assert "stop" in llm.additional_params.get("additional_drop_params", [])
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def ollama_llm():
|
||||
return LLM(model="ollama/llama3.2:3b", is_litellm=True)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from typing import Any, ClassVar
|
||||
from unittest.mock import Mock, patch
|
||||
from unittest.mock import Mock, create_autospec, patch
|
||||
|
||||
import pytest
|
||||
from crewai.agent import Agent
|
||||
@@ -372,8 +372,11 @@ def test_internal_crew_with_mcp():
|
||||
mock_adapter = Mock()
|
||||
mock_adapter.tools = ToolCollection([simple_tool, another_simple_tool])
|
||||
|
||||
mock_llm = Mock()
|
||||
mock_llm.__class__ = BaseLLM
|
||||
class _StubLLM(BaseLLM):
|
||||
def call(self, *a: Any, **kw: Any) -> str:
|
||||
return ""
|
||||
|
||||
mock_llm = create_autospec(_StubLLM(model="stub"), instance=True)
|
||||
|
||||
with (
|
||||
patch("crewai_tools.MCPServerAdapter", return_value=mock_adapter) as adapter_mock,
|
||||
|
||||
Reference in New Issue
Block a user