Mirror of https://github.com/crewAIInc/crewAI.git (synced 2025-12-16 12:28:30 +00:00)

Compare commits (1 commit): devin/1764 ... lorenze/im
Commit: 79f716162b
@@ -316,6 +316,143 @@ class LLM(BaseLLM):
         stream: bool = False,
         **kwargs,
     ):
+        # Check for provider prefixes and route to native implementations
+        if "/" in model:
+            provider, actual_model = model.split("/", 1)
+
+            # Route to OpenAI native implementation
+            if provider.lower() == "openai":
+                try:
+                    from crewai.llms.openai import OpenAILLM
+
+                    # Create native OpenAI instance with all the same parameters
+                    native_llm = OpenAILLM(
+                        model=actual_model,
+                        timeout=timeout,
+                        temperature=temperature,
+                        top_p=top_p,
+                        n=n,
+                        stop=stop,
+                        max_completion_tokens=max_completion_tokens,
+                        max_tokens=max_tokens,
+                        presence_penalty=presence_penalty,
+                        frequency_penalty=frequency_penalty,
+                        logit_bias=logit_bias,
+                        response_format=response_format,
+                        seed=seed,
+                        logprobs=logprobs,
+                        top_logprobs=top_logprobs,
+                        base_url=base_url,
+                        api_base=api_base,
+                        api_version=api_version,
+                        api_key=api_key,
+                        callbacks=callbacks,
+                        reasoning_effort=reasoning_effort,
+                        stream=stream,
+                        **kwargs,
+                    )
+
+                    # Replace this LLM instance with the native one
+                    self.__class__ = native_llm.__class__
+                    self.__dict__.update(native_llm.__dict__)
+                    return
+
+                except ImportError:
+                    # Fall back to LiteLLM if native implementation unavailable
+                    print(
+                        f"Native OpenAI implementation not available, using LiteLLM for {model}"
+                    )
+                    model = actual_model  # Remove the prefix for LiteLLM
+
+            # Route to Claude native implementation
+            elif provider.lower() == "anthropic":
+                try:
+                    from crewai.llms.anthropic import ClaudeLLM
+
+                    # Create native Claude instance with all the same parameters
+                    native_llm = ClaudeLLM(
+                        model=actual_model,
+                        timeout=timeout,
+                        temperature=temperature,
+                        top_p=top_p,
+                        n=n,
+                        stop=stop,
+                        max_completion_tokens=max_completion_tokens,
+                        max_tokens=max_tokens,
+                        presence_penalty=presence_penalty,
+                        frequency_penalty=frequency_penalty,
+                        logit_bias=logit_bias,
+                        response_format=response_format,
+                        seed=seed,
+                        logprobs=logprobs,
+                        top_logprobs=top_logprobs,
+                        base_url=base_url,
+                        api_base=api_base,
+                        api_version=api_version,
+                        api_key=api_key,
+                        callbacks=callbacks,
+                        reasoning_effort=reasoning_effort,
+                        stream=stream,
+                        **kwargs,
+                    )
+
+                    # Replace this LLM instance with the native one
+                    self.__class__ = native_llm.__class__
+                    self.__dict__.update(native_llm.__dict__)
+                    return
+
+                except ImportError:
+                    # Fall back to LiteLLM if native implementation unavailable
+                    print(
+                        f"Native Claude implementation not available, using LiteLLM for {model}"
+                    )
+                    model = actual_model  # Remove the prefix for LiteLLM
+
+            # Route to Gemini native implementation
+            elif provider.lower() == "google":
+                try:
+                    from crewai.llms.google import GeminiLLM
+
+                    # Create native Gemini instance with all the same parameters
+                    native_llm = GeminiLLM(
+                        model=actual_model,
+                        timeout=timeout,
+                        temperature=temperature,
+                        top_p=top_p,
+                        n=n,
+                        stop=stop,
+                        max_completion_tokens=max_completion_tokens,
+                        max_tokens=max_tokens,
+                        presence_penalty=presence_penalty,
+                        frequency_penalty=frequency_penalty,
+                        logit_bias=logit_bias,
+                        response_format=response_format,
+                        seed=seed,
+                        logprobs=logprobs,
+                        top_logprobs=top_logprobs,
+                        base_url=base_url,
+                        api_base=api_base,
+                        api_version=api_version,
+                        api_key=api_key,
+                        callbacks=callbacks,
+                        reasoning_effort=reasoning_effort,
+                        stream=stream,
+                        **kwargs,
+                    )
+
+                    # Replace this LLM instance with the native one
+                    self.__class__ = native_llm.__class__
+                    self.__dict__.update(native_llm.__dict__)
+                    return
+
+                except ImportError:
+                    # Fall back to LiteLLM if native implementation unavailable
+                    print(
+                        f"Native Gemini implementation not available, using LiteLLM for {model}"
+                    )
+                    model = actual_model  # Remove the prefix for LiteLLM
+
+        # Continue with original LiteLLM initialization
         self.model = model
         self.timeout = timeout
         self.temperature = temperature
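For orientation, a minimal sketch of how the prefix routing above behaves from the caller's side (assuming the native modules are importable; the model names are only illustrative):

    from crewai import LLM

    # "openai/gpt-4o" is split into provider "openai" and model "gpt-4o",
    # so the instance is swapped to OpenAILLM via __class__ / __dict__.
    llm = LLM(model="openai/gpt-4o", temperature=0.2)
    print(type(llm).__name__)  # expected: OpenAILLM

    # Without a "/" prefix (or with a provider that has no native class),
    # initialization falls through to the original LiteLLM-based path.
    fallback = LLM(model="gpt-4o")
    print(type(fallback).__name__)  # expected: LLM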
@@ -1139,7 +1276,11 @@ class LLM(BaseLLM):
 
         # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
         # Ollama doesn't support the last message being 'assistant'
-        if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
+        if (
+            "ollama" in self.model.lower()
+            and messages
+            and messages[-1]["role"] == "assistant"
+        ):
             return messages + [{"role": "user", "content": ""}]
 
         # Handle Anthropic models
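A short illustration of the workaround above (a hypothetical message list; the empty user turn is only appended for models whose name contains "ollama"):

    messages = [
        {"role": "user", "content": "Summarize the report."},
        {"role": "assistant", "content": "Sure, here is a summary:"},
    ]
    # Because the last message is from the assistant, an empty user turn is
    # appended so the Ollama endpoint accepts the conversation.
    patched = messages + [{"role": "user", "content": ""}]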
@@ -1 +1,11 @@
-"""LLM implementations for crewAI."""
+"""CrewAI LLM implementations."""
+
+from .base_llm import BaseLLM
+from .openai import OpenAILLM
+from .anthropic import ClaudeLLM
+from .google import GeminiLLM
+
+# Import the main LLM class for backward compatibility
+
+
+__all__ = ["BaseLLM", "OpenAILLM", "ClaudeLLM", "GeminiLLM"]
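Assuming the package layout introduced below, the new classes can also be constructed directly instead of going through the provider-prefix routing (a sketch; API keys are read from the usual environment variables such as OPENAI_API_KEY and ANTHROPIC_API_KEY):

    from crewai.llms import OpenAILLM, ClaudeLLM, GeminiLLM

    # Direct construction bypasses the "provider/model" string parsing in LLM.__init__.
    openai_llm = OpenAILLM(model="gpt-4o", temperature=0.1)
    claude_llm = ClaudeLLM(model="claude-3-5-sonnet-20241022", max_tokens=1024)
    gemini_llm = GeminiLLM(model="gemini-1.5-flash")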
src/crewai/llms/anthropic/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+"""Anthropic Claude LLM implementation for CrewAI."""
+
+from .claude import ClaudeLLM
+
+__all__ = ["ClaudeLLM"]
src/crewai/llms/anthropic/claude.py (new file, 569 lines)

@@ -0,0 +1,569 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal
|
||||
from anthropic import Anthropic
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class ClaudeLLM(BaseLLM):
|
||||
"""Anthropic Claude LLM implementation with full LLM class compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "claude-3-5-sonnet-20241022",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None, # Not supported by Claude but kept for compatibility
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
frequency_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
logit_bias: Optional[
|
||||
Dict[int, float]
|
||||
] = None, # Not supported but kept for compatibility
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None, # Not supported but kept for compatibility
|
||||
logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None, # Not used by Anthropic
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[
|
||||
Literal["none", "low", "medium", "high"]
|
||||
] = None, # Not used by Claude
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
# Claude-specific parameters
|
||||
thinking_mode: bool = False, # Enable Claude's thinking mode
|
||||
top_k: Optional[int] = None, # Claude-specific sampling parameter
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Claude LLM with full compatibility.
|
||||
|
||||
Args:
|
||||
model: Claude model name (e.g., 'claude-3-5-sonnet-20241022')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-1 for Claude)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions (not supported by Claude, kept for compatibility)
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Not supported by Claude, kept for compatibility
|
||||
frequency_penalty: Not supported by Claude, kept for compatibility
|
||||
logit_bias: Not supported by Claude, kept for compatibility
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Not supported by Claude, kept for compatibility
|
||||
logprobs: Not supported by Claude, kept for compatibility
|
||||
top_logprobs: Not supported by Claude, kept for compatibility
|
||||
base_url: Custom API base URL
|
||||
api_base: Legacy API base parameter
|
||||
api_version: Not used by Anthropic
|
||||
api_key: Anthropic API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Not used by Claude, kept for compatibility
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
thinking_mode: Enable Claude's thinking mode (if supported)
|
||||
top_k: Claude-specific top-k sampling parameter
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n # Claude doesn't support n>1, but we store it for compatibility
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base or base_url
|
||||
self.base_url = base_url or api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
|
||||
# Claude-specific parameters
|
||||
self.thinking_mode = thinking_mode
|
||||
self.top_k = top_k
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize Anthropic client
|
||||
client_kwargs = {}
|
||||
if self.api_key:
|
||||
client_kwargs["api_key"] = self.api_key
|
||||
if self.base_url:
|
||||
client_kwargs["base_url"] = self.base_url
|
||||
if self.timeout:
|
||||
client_kwargs["timeout"] = self.timeout
|
||||
if max_retries:
|
||||
client_kwargs["max_retries"] = max_retries
|
||||
|
||||
# Add any additional kwargs that might be relevant to the client
|
||||
for key, value in kwargs.items():
|
||||
if key not in ["thinking_mode", "top_k"]: # Exclude our custom params
|
||||
client_kwargs[key] = value
|
||||
|
||||
self.client = Anthropic(**client_kwargs)
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration for Claude models."""
|
||||
# Claude model configurations based on Anthropic's documentation
|
||||
model_configs = {
|
||||
# Claude 3.5 Sonnet
|
||||
"claude-3-5-sonnet-20241022": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"claude-3-5-sonnet-20240620": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3.5 Haiku
|
||||
"claude-3-5-haiku-20241022": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Opus
|
||||
"claude-3-opus-20240229": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Sonnet
|
||||
"claude-3-sonnet-20240229": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 3 Haiku
|
||||
"claude-3-haiku-20240307": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Claude 2.1
|
||||
"claude-2.1": {
|
||||
"context_window": 200000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
"claude-2": {
|
||||
"context_window": 100000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
# Claude Instant
|
||||
"claude-instant-1.2": {
|
||||
"context_window": 100000,
|
||||
"supports_tools": False,
|
||||
"supports_vision": False,
|
||||
},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {
|
||||
"context_window": 200000,
|
||||
"supports_tools": True,
|
||||
"supports_vision": False,
|
||||
}
|
||||
|
||||
# Try exact match first
|
||||
if self.model in model_configs:
|
||||
return model_configs[self.model]
|
||||
|
||||
# Try prefix match for versioned models
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
def _format_messages(
|
||||
self, messages: Union[str, List[Dict[str, str]]]
|
||||
) -> List[Dict[str, str]]:
|
||||
"""Format messages for Anthropic API.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of dicts
|
||||
|
||||
Returns:
|
||||
List of properly formatted message dicts
|
||||
"""
|
||||
if isinstance(messages, str):
|
||||
return [{"role": "user", "content": messages}]
|
||||
|
||||
# Validate message format
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
"Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Claude requires alternating user/assistant messages and cannot start with system
|
||||
formatted_messages = []
|
||||
system_message = None
|
||||
|
||||
for msg in messages:
|
||||
if msg["role"] == "system":
|
||||
# Store system message separately - Claude handles it differently
|
||||
if system_message is None:
|
||||
system_message = msg["content"]
|
||||
else:
|
||||
system_message += "\n\n" + msg["content"]
|
||||
else:
|
||||
formatted_messages.append(msg)
|
||||
|
||||
# Ensure messages alternate and start with user
|
||||
if formatted_messages and formatted_messages[0]["role"] != "user":
|
||||
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
|
||||
|
||||
# Store system message for later use
|
||||
self._system_message = system_message
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
|
||||
"""Format tools for Claude function calling.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Claude-formatted tool definitions
|
||||
"""
|
||||
if not tools or not self.model_config.get("supports_tools", True):
|
||||
return None
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
# Convert to Claude tool format
|
||||
formatted_tool = {
|
||||
"name": tool.get("name", ""),
|
||||
"description": tool.get("description", ""),
|
||||
"input_schema": tool.get("parameters", {}),
|
||||
}
|
||||
formatted_tools.append(formatted_tool)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
def _handle_tool_calls(
|
||||
self,
|
||||
response,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Handle tool calls from Claude response.
|
||||
|
||||
Args:
|
||||
response: Claude API response
|
||||
available_functions: Dict mapping function names to callables
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
Result of function execution or error message
|
||||
"""
|
||||
# Claude returns tool use in content blocks
|
||||
if not hasattr(response, "content") or not available_functions:
|
||||
return response.content[0].text if response.content else ""
|
||||
|
||||
# Look for tool use blocks
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "type") and content_block.type == "tool_use":
|
||||
function_name = content_block.name
|
||||
function_args = {}
|
||||
|
||||
if function_name not in available_functions:
|
||||
return f"Error: Function '{function_name}' not found in available functions"
|
||||
|
||||
try:
|
||||
# Claude provides arguments as a dict
|
||||
function_args = content_block.input
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# Execute function with event tracking
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
),
|
||||
)
|
||||
|
||||
result = fn(**function_args)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit success event
|
||||
event_data = {
|
||||
"response": result,
|
||||
"call_type": LLMCallType.TOOL_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
|
||||
# If no tool calls, return text content
|
||||
return response.content[0].text if response.content else ""
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call Claude API with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional callbacks to execute
|
||||
available_functions: Optional dict of available functions
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
LLM response or tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If messages format is invalid
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
# Emit call started event
|
||||
print("calling from native claude", messages)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
|
||||
# Prepare event data
|
||||
started_event_data = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
started_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
started_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(**started_event_data),
|
||||
)
|
||||
|
||||
try:
|
||||
# Format messages
|
||||
formatted_messages = self._format_messages(messages)
|
||||
system_message = getattr(self, "_system_message", None)
|
||||
|
||||
# Prepare API call parameters
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
"max_tokens": self.max_tokens or 4000, # Claude requires max_tokens
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
api_params["system"] = system_message
|
||||
|
||||
# Add optional parameters that Claude supports
|
||||
if self.temperature is not None:
|
||||
api_params["temperature"] = self.temperature
|
||||
|
||||
if self.top_p is not None:
|
||||
api_params["top_p"] = self.top_p
|
||||
|
||||
if self.top_k is not None:
|
||||
api_params["top_k"] = self.top_k
|
||||
|
||||
if self.stop:
|
||||
api_params["stop_sequences"] = self.stop
|
||||
|
||||
# Add tools if provided and supported
|
||||
formatted_tools = self._format_tools(tools)
|
||||
if formatted_tools:
|
||||
api_params["tools"] = formatted_tools
|
||||
|
||||
# Execute callbacks before API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_start"):
|
||||
callback.on_llm_start(
|
||||
serialized={"name": self.__class__.__name__},
|
||||
prompts=[str(formatted_messages)],
|
||||
)
|
||||
|
||||
# Make API call
|
||||
if self.stream:
|
||||
response = self.client.messages.create(stream=True, **api_params)
|
||||
# Handle streaming (simplified implementation)
|
||||
full_response = ""
|
||||
try:
|
||||
for event in response:
|
||||
if hasattr(event, "type"):
|
||||
if event.type == "content_block_delta":
|
||||
if hasattr(event, "delta") and hasattr(
|
||||
event.delta, "text"
|
||||
):
|
||||
full_response += event.delta.text
|
||||
except Exception as e:
|
||||
# If streaming fails, fall back to the response we have
|
||||
print(f"Streaming error (continuing with partial response): {e}")
|
||||
result = full_response or "No response content"
|
||||
else:
|
||||
response = self.client.messages.create(**api_params)
|
||||
# Handle tool calls if present
|
||||
result = self._handle_tool_calls(
|
||||
response, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
# Execute callbacks after API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_end"):
|
||||
callback.on_llm_end(response=result)
|
||||
|
||||
# Emit completion event
|
||||
completion_event_data = {
|
||||
"messages": formatted_messages,
|
||||
"response": result,
|
||||
"call_type": LLMCallType.LLM_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
completion_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
completion_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**completion_event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
# Execute error callbacks
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_error"):
|
||||
callback.on_llm_error(error=e)
|
||||
|
||||
# Emit failed event
|
||||
failed_event_data = {
|
||||
"error": str(e),
|
||||
}
|
||||
if from_task is not None:
|
||||
failed_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
failed_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(**failed_event_data),
|
||||
)
|
||||
|
||||
raise RuntimeError(f"Claude API call failed: {str(e)}") from e
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if Claude models support stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model."""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
# Use 85% of the context window like the original LLM class
|
||||
context_window = self.model_config.get("context_window", 200000)
|
||||
self.context_window_size = int(context_window * 0.85)
|
||||
return self.context_window_size
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the current model supports function calling."""
|
||||
return self.model_config.get("supports_tools", True)
|
||||
|
||||
def supports_vision(self) -> bool:
|
||||
"""Check if the current model supports vision capabilities."""
|
||||
return self.model_config.get("supports_vision", False)
|
||||
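As a usage sketch for the ClaudeLLM class above (the tool schema and function are hypothetical; ANTHROPIC_API_KEY is assumed to be set):

    from crewai.llms.anthropic import ClaudeLLM

    llm = ClaudeLLM(model="claude-3-5-sonnet-20241022", temperature=0.3)

    def get_weather(city: str) -> str:
        return f"It is sunny in {city}."

    # call() returns the tool result if Claude decides to use the tool,
    # otherwise the text of the reply.
    answer = llm.call(
        messages=[
            {"role": "system", "content": "You are a concise assistant."},
            {"role": "user", "content": "What's the weather in Paris?"},
        ],
        tools=[{
            "name": "get_weather",
            "description": "Look up the weather for a city",
            "parameters": {
                "type": "object",
                "properties": {"city": {"type": "string"}},
                "required": ["city"],
            },
        }],
        available_functions={"get_weather": get_weather},
    )
    print(answer)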
src/crewai/llms/google/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+"""Google Gemini LLM implementation for CrewAI."""
+
+from .gemini import GeminiLLM
+
+__all__ = ["GeminiLLM"]

src/crewai/llms/google/gemini.py (new file, 737 lines)

@@ -0,0 +1,737 @@
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal, TYPE_CHECKING
|
||||
from pydantic import BaseModel
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
|
||||
try:
|
||||
from google import genai
|
||||
from google.genai import types
|
||||
except ImportError:
|
||||
genai = None
|
||||
types = None
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class GeminiLLM(BaseLLM):
|
||||
"""Google Gemini LLM implementation using the official Google Gen AI Python SDK."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gemini-1.5-pro",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None, # Not supported by Gemini but kept for compatibility
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
frequency_penalty: Optional[
|
||||
float
|
||||
] = None, # Not supported but kept for compatibility
|
||||
logit_bias: Optional[
|
||||
Dict[int, float]
|
||||
] = None, # Not supported but kept for compatibility
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None, # Not supported but kept for compatibility
|
||||
logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
|
||||
base_url: Optional[str] = None, # Not used by Gemini
|
||||
api_base: Optional[str] = None, # Not used by Gemini
|
||||
api_version: Optional[str] = None, # Not used by Gemini
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[
|
||||
Literal["none", "low", "medium", "high"]
|
||||
] = None, # Not used by Gemini
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
# Gemini-specific parameters
|
||||
top_k: Optional[int] = None, # Gemini top-k sampling parameter
|
||||
candidate_count: int = 1, # Number of response candidates
|
||||
safety_settings: Optional[
|
||||
List[Dict[str, Any]]
|
||||
] = None, # Gemini safety settings
|
||||
generation_config: Optional[
|
||||
Dict[str, Any]
|
||||
] = None, # Additional generation config
|
||||
# Vertex AI parameters
|
||||
use_vertex_ai: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
location: str = "us-central1",
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize Gemini LLM with the official Google Gen AI SDK.
|
||||
|
||||
Args:
|
||||
model: Gemini model name (e.g., 'gemini-1.5-pro', 'gemini-2.0-flash-001')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-2 for Gemini)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions (not supported by Gemini, kept for compatibility)
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Not supported by Gemini, kept for compatibility
|
||||
frequency_penalty: Not supported by Gemini, kept for compatibility
|
||||
logit_bias: Not supported by Gemini, kept for compatibility
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Not supported by Gemini, kept for compatibility
|
||||
logprobs: Not supported by Gemini, kept for compatibility
|
||||
top_logprobs: Not supported by Gemini, kept for compatibility
|
||||
base_url: Not used by Gemini
|
||||
api_base: Not used by Gemini
|
||||
api_version: Not used by Gemini
|
||||
api_key: Google AI API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Not used by Gemini, kept for compatibility
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
top_k: Gemini-specific top-k sampling parameter
|
||||
candidate_count: Number of response candidates to generate
|
||||
safety_settings: Gemini safety settings configuration
|
||||
generation_config: Additional Gemini generation configuration
|
||||
use_vertex_ai: Whether to use Vertex AI instead of Gemini API
|
||||
project_id: Google Cloud project ID (required for Vertex AI)
|
||||
location: Google Cloud region (default: us-central1)
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
# Check if Google Gen AI SDK is available
|
||||
if genai is None or types is None:
|
||||
raise ImportError(
|
||||
"Google Gen AI Python SDK is required. Please install it with: "
|
||||
"pip install google-genai"
|
||||
)
|
||||
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base
|
||||
self.base_url = base_url
|
||||
self.api_version = api_version
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
self.max_retries = max_retries
|
||||
|
||||
# Gemini-specific parameters
|
||||
self.top_k = top_k
|
||||
self.candidate_count = candidate_count
|
||||
self.safety_settings = safety_settings or []
|
||||
self.generation_config = generation_config or {}
|
||||
|
||||
# Vertex AI parameters
|
||||
self.use_vertex_ai = use_vertex_ai
|
||||
self.project_id = project_id or os.getenv("GOOGLE_CLOUD_PROJECT")
|
||||
self.location = location
|
||||
|
||||
# API key handling
|
||||
self.api_key = (
|
||||
api_key
|
||||
or os.getenv("GOOGLE_AI_API_KEY")
|
||||
or os.getenv("GEMINI_API_KEY")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
)
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize client attribute
|
||||
self.client: Any = None
|
||||
|
||||
# Initialize the Google Gen AI client
|
||||
self._initialize_client()
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _initialize_client(self):
|
||||
"""Initialize the Google Gen AI client."""
|
||||
if genai is None or types is None:
|
||||
return
|
||||
|
||||
try:
|
||||
if self.use_vertex_ai:
|
||||
if not self.project_id:
|
||||
raise ValueError(
|
||||
"project_id is required when use_vertex_ai=True. "
|
||||
"Set it directly or via GOOGLE_CLOUD_PROJECT environment variable."
|
||||
)
|
||||
self.client = genai.Client(
|
||||
vertexai=True,
|
||||
project=self.project_id,
|
||||
location=self.location,
|
||||
)
|
||||
else:
|
||||
if not self.api_key:
|
||||
raise ValueError(
|
||||
"API key is required for Gemini Developer API. "
|
||||
"Set it via api_key parameter or GOOGLE_AI_API_KEY/GEMINI_API_KEY environment variable."
|
||||
)
|
||||
self.client = genai.Client(api_key=self.api_key)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"Failed to initialize Google Gen AI client: {str(e)}"
|
||||
) from e
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration for Gemini models."""
|
||||
# Gemini model configurations based on Google's documentation
|
||||
model_configs = {
|
||||
# Gemini 2.0 Flash (latest)
|
||||
"gemini-2.0-flash": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.0-flash-001": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-2.0-flash-exp": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini 1.5 Pro
|
||||
"gemini-1.5-pro": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-002": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-001": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-pro-exp-0827": {
|
||||
"context_window": 2097152,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini 1.5 Flash
|
||||
"gemini-1.5-flash": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-002": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-001": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-8b": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
"gemini-1.5-flash-8b-exp-0827": {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Legacy Gemini Pro
|
||||
"gemini-pro": {
|
||||
"context_window": 30720,
|
||||
"supports_tools": True,
|
||||
"supports_vision": False,
|
||||
},
|
||||
"gemini-pro-vision": {
|
||||
"context_window": 16384,
|
||||
"supports_tools": False,
|
||||
"supports_vision": True,
|
||||
},
|
||||
# Gemini Ultra (when available)
|
||||
"gemini-ultra": {
|
||||
"context_window": 30720,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {
|
||||
"context_window": 1048576,
|
||||
"supports_tools": True,
|
||||
"supports_vision": True,
|
||||
}
|
||||
|
||||
# Try exact match first
|
||||
if self.model in model_configs:
|
||||
return model_configs[self.model]
|
||||
|
||||
# Try prefix match for versioned models
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
def _format_messages(self, messages: Union[str, List[Dict[str, str]]]) -> List[Any]:
|
||||
"""Format messages for Google Gen AI SDK.
|
||||
|
||||
Args:
|
||||
messages: Input messages as string or list of dicts
|
||||
|
||||
Returns:
|
||||
List of properly formatted Content objects
|
||||
"""
|
||||
if genai is None or types is None:
|
||||
return []
|
||||
|
||||
if isinstance(messages, str):
|
||||
return [
|
||||
types.Content(role="user", parts=[types.Part.from_text(text=messages)])
|
||||
]
|
||||
|
||||
# Validate message format
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise ValueError(
|
||||
"Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Convert to Google Gen AI SDK format
|
||||
formatted_messages = []
|
||||
system_instruction = None
|
||||
|
||||
for msg in messages:
|
||||
role = msg["role"]
|
||||
content = msg["content"]
|
||||
|
||||
if role == "system":
|
||||
# System instruction will be handled separately
|
||||
system_instruction = content
|
||||
elif role == "user":
|
||||
formatted_messages.append(
|
||||
types.Content(
|
||||
role="user", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
elif role == "assistant":
|
||||
formatted_messages.append(
|
||||
types.Content(
|
||||
role="model", parts=[types.Part.from_text(text=content)]
|
||||
)
|
||||
)
|
||||
|
||||
# Store system instruction for later use
|
||||
self._system_instruction = system_instruction
|
||||
|
||||
return formatted_messages
|
||||
|
||||
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[Any]]:
|
||||
"""Format tools for Google Gen AI SDK function calling.
|
||||
|
||||
Args:
|
||||
tools: List of tool definitions
|
||||
|
||||
Returns:
|
||||
Google Gen AI SDK formatted tool definitions
|
||||
"""
|
||||
if genai is None or types is None:
|
||||
return None
|
||||
|
||||
if not tools or not self.model_config.get("supports_tools", True):
|
||||
return None
|
||||
|
||||
formatted_tools = []
|
||||
for tool in tools:
|
||||
# Convert to Google Gen AI SDK function declaration format
|
||||
function_declaration = types.FunctionDeclaration(
|
||||
name=tool.get("name", ""),
|
||||
description=tool.get("description", ""),
|
||||
parameters=tool.get("parameters", {}),
|
||||
)
|
||||
formatted_tools.append(
|
||||
types.Tool(function_declarations=[function_declaration])
|
||||
)
|
||||
|
||||
return formatted_tools
|
||||
|
||||
def _build_generation_config(
|
||||
self,
|
||||
system_instruction: Optional[str] = None,
|
||||
tools: Optional[List[Any]] = None,
|
||||
) -> Any:
|
||||
"""Build Google Gen AI SDK generation config from parameters."""
|
||||
if genai is None or types is None:
|
||||
return {}
|
||||
config_dict = self.generation_config.copy()
|
||||
|
||||
# Add parameters that map to Gemini's generation config
|
||||
if self.temperature is not None:
|
||||
config_dict["temperature"] = self.temperature
|
||||
|
||||
if self.top_p is not None:
|
||||
config_dict["top_p"] = self.top_p
|
||||
|
||||
if self.top_k is not None:
|
||||
config_dict["top_k"] = self.top_k
|
||||
|
||||
if self.max_tokens is not None:
|
||||
config_dict["max_output_tokens"] = self.max_tokens
|
||||
|
||||
if self.candidate_count is not None:
|
||||
config_dict["candidate_count"] = self.candidate_count
|
||||
|
||||
if self.stop:
|
||||
config_dict["stop_sequences"] = self.stop
|
||||
|
||||
if self.stream:
|
||||
config_dict["stream"] = True
|
||||
|
||||
# Add safety settings
|
||||
if self.safety_settings:
|
||||
config_dict["safety_settings"] = self.safety_settings
|
||||
|
||||
# Add response format if specified
|
||||
if self.response_format:
|
||||
config_dict["response_modalities"] = ["TEXT"]
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
config_dict["system_instruction"] = system_instruction
|
||||
|
||||
# Add tools if present
|
||||
if tools:
|
||||
config_dict["tools"] = tools
|
||||
|
||||
return types.GenerateContentConfig(**config_dict)
|
||||
|
||||
def _handle_tool_calls(
|
||||
self,
|
||||
response,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Any:
|
||||
"""Handle tool calls from Google Gen AI SDK response.
|
||||
|
||||
Args:
|
||||
response: Google Gen AI SDK response
|
||||
available_functions: Dict mapping function names to callables
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
Result of function execution or error message
|
||||
"""
|
||||
# Check if response has function calls
|
||||
if (
|
||||
not available_functions
|
||||
or not hasattr(response, "candidates")
|
||||
or not response.candidates
|
||||
):
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
candidate = response.candidates[0] if response.candidates else None
|
||||
if (
|
||||
not candidate
|
||||
or not hasattr(candidate, "content")
|
||||
or not hasattr(candidate.content, "parts")
|
||||
):
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
# Look for function call parts
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call"):
|
||||
function_call = part.function_call
|
||||
function_name = function_call.name
|
||||
function_args = {}
|
||||
|
||||
if function_name not in available_functions:
|
||||
return f"Error: Function '{function_name}' not found in available functions"
|
||||
|
||||
try:
|
||||
# Google Gen AI SDK provides arguments as a struct
|
||||
function_args = (
|
||||
dict(function_call.args)
|
||||
if hasattr(function_call, "args")
|
||||
else {}
|
||||
)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
# Execute function with event tracking
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageStartedEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
),
|
||||
)
|
||||
|
||||
result = fn(**function_args)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
output=result,
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
started_at=started_at,
|
||||
finished_at=datetime.now(),
|
||||
),
|
||||
)
|
||||
|
||||
# Emit success event
|
||||
event_data = {
|
||||
"response": result,
|
||||
"call_type": LLMCallType.TOOL_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
error_msg = f"Error executing function '{function_name}': {e}"
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageErrorEvent(
|
||||
tool_name=function_name,
|
||||
tool_args=function_args,
|
||||
error=error_msg,
|
||||
),
|
||||
)
|
||||
return error_msg
|
||||
|
||||
# If no function calls, return text content
|
||||
return response.text if hasattr(response, "text") else str(response)
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call Google Gen AI SDK with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM
|
||||
tools: Optional list of tool schemas
|
||||
callbacks: Optional callbacks to execute
|
||||
available_functions: Optional dict of available functions
|
||||
from_task: Optional task context
|
||||
from_agent: Optional agent context
|
||||
|
||||
Returns:
|
||||
LLM response or tool execution result
|
||||
|
||||
Raises:
|
||||
ValueError: If messages format is invalid
|
||||
RuntimeError: If API call fails
|
||||
"""
|
||||
# Emit call started event
|
||||
print("calling from native gemini", messages)
|
||||
assert hasattr(crewai_event_bus, "emit")
|
||||
|
||||
# Prepare event data
|
||||
started_event_data = {
|
||||
"messages": messages,
|
||||
"tools": tools,
|
||||
"callbacks": callbacks,
|
||||
"available_functions": available_functions,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
started_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
started_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallStartedEvent(**started_event_data),
|
||||
)
|
||||
|
||||
retry_count = 0
|
||||
last_error = None
|
||||
|
||||
while retry_count <= self.max_retries:
|
||||
try:
|
||||
# Format messages
|
||||
formatted_messages = self._format_messages(messages)
|
||||
system_instruction = getattr(self, "_system_instruction", None)
|
||||
|
||||
# Format tools if provided and supported
|
||||
formatted_tools = self._format_tools(tools)
|
||||
|
||||
# Build generation config
|
||||
generation_config = self._build_generation_config(
|
||||
system_instruction, formatted_tools
|
||||
)
|
||||
|
||||
# Execute callbacks before API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_start"):
|
||||
callback.on_llm_start(
|
||||
serialized={"name": self.__class__.__name__},
|
||||
prompts=[str(formatted_messages)],
|
||||
)
|
||||
|
||||
# Prepare the API call parameters
|
||||
api_params = {
|
||||
"model": self.model,
|
||||
"contents": formatted_messages,
|
||||
"config": generation_config,
|
||||
}
|
||||
|
||||
# Make API call
|
||||
if self.stream:
|
||||
# Streaming response
|
||||
response_stream = self.client.models.generate_content(**api_params)
|
||||
|
||||
full_response = ""
|
||||
try:
|
||||
for chunk in response_stream:
|
||||
if hasattr(chunk, "text") and chunk.text:
|
||||
full_response += chunk.text
|
||||
except Exception as e:
|
||||
print(
|
||||
f"Streaming error (continuing with partial response): {e}"
|
||||
)
|
||||
|
||||
result = full_response or "No response content"
|
||||
else:
|
||||
# Non-streaming response
|
||||
response = self.client.models.generate_content(**api_params)
|
||||
|
||||
# Handle tool calls if present
|
||||
result = self._handle_tool_calls(
|
||||
response, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
# Execute callbacks after API call
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_end"):
|
||||
callback.on_llm_end(response=result)
|
||||
|
||||
# Emit completion event
|
||||
completion_event_data = {
|
||||
"messages": messages, # Use original messages, not formatted_messages
|
||||
"response": result,
|
||||
"call_type": LLMCallType.LLM_CALL,
|
||||
"model": self.model,
|
||||
}
|
||||
if from_task is not None:
|
||||
completion_event_data["from_task"] = from_task
|
||||
if from_agent is not None:
|
||||
completion_event_data["from_agent"] = from_agent
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallCompletedEvent(**completion_event_data),
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
last_error = e
|
||||
retry_count += 1
|
||||
|
||||
if retry_count <= self.max_retries:
|
||||
print(
|
||||
f"Gemini API call failed (attempt {retry_count}/{self.max_retries + 1}): {e}"
|
||||
)
|
||||
continue
|
||||
|
||||
# All retries exhausted
|
||||
# Execute error callbacks
|
||||
if callbacks:
|
||||
for callback in callbacks:
|
||||
if hasattr(callback, "on_llm_error"):
|
||||
callback.on_llm_error(error=e)
|
||||
|
||||
# Emit failed event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LLMCallFailedEvent(error=str(e)),
|
||||
)
|
||||
|
||||
raise RuntimeError(
|
||||
f"Gemini API call failed after {self.max_retries + 1} attempts: {str(e)}"
|
||||
) from e
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if Gemini models support stop words."""
|
||||
return True
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the current model."""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
# Use 85% of the context window like the original LLM class
|
||||
context_window = self.model_config.get("context_window", 1048576)
|
||||
self.context_window_size = int(context_window * 0.85)
|
||||
return self.context_window_size
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the current model supports function calling."""
|
||||
return self.model_config.get("supports_tools", True)
|
||||
|
||||
def supports_vision(self) -> bool:
|
||||
"""Check if the current model supports vision capabilities."""
|
||||
return self.model_config.get("supports_vision", False)
|
||||
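A similar sketch for the GeminiLLM class above, including the Vertex AI path (project and region values are placeholders; the google-genai package must be installed, and the developer-API key is read from GEMINI_API_KEY / GOOGLE_API_KEY):

    from crewai.llms.google import GeminiLLM

    # Gemini Developer API
    gemini = GeminiLLM(model="gemini-1.5-flash", temperature=0.2)
    print(gemini.call("Give me three facts about the Moon."))

    # Vertex AI variant (assumes Google Cloud credentials are configured)
    vertex_gemini = GeminiLLM(
        model="gemini-1.5-pro",
        use_vertex_ai=True,
        project_id="my-gcp-project",  # placeholder
        location="us-central1",
    )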
src/crewai/llms/openai/__init__.py (new file, 5 lines)

@@ -0,0 +1,5 @@
+"""OpenAI LLM implementation for CrewAI."""
+
+from .chat import OpenAILLM
+
+__all__ = ["OpenAILLM"]

src/crewai/llms/openai/chat.py (new file, 529 lines)

@@ -0,0 +1,529 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union, Type, Literal
|
||||
from openai import OpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.events import crewai_event_bus
|
||||
from crewai.utilities.events.llm_events import (
|
||||
LLMCallCompletedEvent,
|
||||
LLMCallFailedEvent,
|
||||
LLMCallStartedEvent,
|
||||
LLMCallType,
|
||||
)
|
||||
from crewai.utilities.events.tool_usage_events import (
|
||||
ToolUsageStartedEvent,
|
||||
ToolUsageFinishedEvent,
|
||||
ToolUsageErrorEvent,
|
||||
)
|
||||
from datetime import datetime
|
||||
|
||||
|
||||
class OpenAILLM(BaseLLM):
|
||||
"""OpenAI LLM implementation with full LLM class compatibility."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str = "gpt-4",
|
||||
timeout: Optional[Union[float, int]] = None,
|
||||
temperature: Optional[float] = None,
|
||||
top_p: Optional[float] = None,
|
||||
n: Optional[int] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
max_completion_tokens: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
presence_penalty: Optional[float] = None,
|
||||
frequency_penalty: Optional[float] = None,
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Type[BaseModel]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_base: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
|
||||
stream: bool = False,
|
||||
max_retries: int = 2,
|
||||
**kwargs,
|
||||
):
|
||||
"""Initialize OpenAI LLM with full compatibility.
|
||||
|
||||
Args:
|
||||
model: OpenAI model name (e.g., 'gpt-4', 'gpt-3.5-turbo')
|
||||
timeout: Request timeout in seconds
|
||||
temperature: Sampling temperature (0-2)
|
||||
top_p: Nucleus sampling parameter
|
||||
n: Number of completions to generate
|
||||
stop: Stop sequences
|
||||
max_completion_tokens: Maximum tokens in completion
|
||||
max_tokens: Maximum tokens (legacy parameter)
|
||||
presence_penalty: Presence penalty (-2 to 2)
|
||||
frequency_penalty: Frequency penalty (-2 to 2)
|
||||
logit_bias: Logit bias dictionary
|
||||
response_format: Pydantic model for structured output
|
||||
seed: Random seed for deterministic output
|
||||
logprobs: Whether to return log probabilities
|
||||
top_logprobs: Number of most likely tokens to return
|
||||
base_url: Custom API base URL
|
||||
api_base: Legacy API base parameter
|
||||
api_version: API version (for Azure)
|
||||
api_key: OpenAI API key
|
||||
callbacks: List of callback functions
|
||||
reasoning_effort: Reasoning effort for o1 models
|
||||
stream: Whether to stream responses
|
||||
max_retries: Number of retries for failed requests
|
||||
**kwargs: Additional parameters
|
||||
"""
|
||||
super().__init__(model=model, temperature=temperature)
|
||||
|
||||
# Store all parameters for compatibility
|
||||
self.timeout = timeout
|
||||
self.top_p = top_p
|
||||
self.n = n
|
||||
self.max_completion_tokens = max_completion_tokens
|
||||
self.max_tokens = max_tokens or max_completion_tokens
|
||||
self.presence_penalty = presence_penalty
|
||||
self.frequency_penalty = frequency_penalty
|
||||
self.logit_bias = logit_bias
|
||||
self.response_format = response_format
|
||||
self.seed = seed
|
||||
self.logprobs = logprobs
|
||||
self.top_logprobs = top_logprobs
|
||||
self.api_base = api_base or base_url
|
||||
self.base_url = base_url or api_base
|
||||
self.api_version = api_version
|
||||
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
|
||||
self.callbacks = callbacks
|
||||
self.reasoning_effort = reasoning_effort
|
||||
self.stream = stream
|
||||
self.additional_params = kwargs
|
||||
self.context_window_size = 0
|
||||
|
||||
# Normalize stop parameter to match LLM class behavior
|
||||
if stop is None:
|
||||
self.stop: List[str] = []
|
||||
elif isinstance(stop, str):
|
||||
self.stop = [stop]
|
||||
else:
|
||||
self.stop = stop
|
||||
|
||||
# Initialize OpenAI client
|
||||
self.client = OpenAI(
|
||||
api_key=self.api_key,
|
||||
base_url=self.base_url,
|
||||
timeout=self.timeout,
|
||||
max_retries=max_retries,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
self.model_config = self._get_model_config()
|
||||
|
||||
def _get_model_config(self) -> Dict[str, Any]:
|
||||
"""Get model-specific configuration."""
|
||||
# Enhanced model configurations matching current LLM_CONTEXT_WINDOW_SIZES
|
||||
model_configs = {
|
||||
"gpt-4": {"context_window": 8192, "supports_tools": True},
|
||||
"gpt-4o": {"context_window": 128000, "supports_tools": True},
|
||||
"gpt-4o-mini": {"context_window": 200000, "supports_tools": True},
|
||||
"gpt-4-turbo": {"context_window": 128000, "supports_tools": True},
|
||||
"gpt-4.1": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-4.1-mini": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-4.1-nano": {"context_window": 1047576, "supports_tools": True},
|
||||
"gpt-3.5-turbo": {"context_window": 16385, "supports_tools": True},
|
||||
"o1-preview": {"context_window": 128000, "supports_tools": False},
|
||||
"o1-mini": {"context_window": 128000, "supports_tools": False},
|
||||
"o3-mini": {"context_window": 200000, "supports_tools": False},
|
||||
"o4-mini": {"context_window": 200000, "supports_tools": False},
|
||||
}
|
||||
|
||||
# Default config if model not found
|
||||
default_config = {"context_window": 4096, "supports_tools": True}
|
||||
|
||||
for model_prefix, config in model_configs.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return config
|
||||
|
||||
return default_config
|
||||
|
||||
    def _format_messages(
        self, messages: Union[str, List[Dict[str, str]]]
    ) -> List[Dict[str, str]]:
        """Format messages for OpenAI API.

        Args:
            messages: Input messages as string or list of dicts

        Returns:
            List of properly formatted message dicts
        """
        if isinstance(messages, str):
            return [{"role": "user", "content": messages}]

        # Validate message format
        for msg in messages:
            if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
                raise ValueError(
                    "Each message must be a dict with 'role' and 'content' keys"
                )

        # Handle O1 model special case (system messages not supported)
        if "o1" in self.model.lower():
            formatted_messages = []
            for msg in messages:
                if msg["role"] == "system":
                    # Convert system messages to assistant messages for O1
                    formatted_messages.append(
                        {"role": "assistant", "content": msg["content"]}
                    )
                else:
                    formatted_messages.append(msg)
            return formatted_messages

        return messages

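    # Illustrative sketch, assuming the method above and an existing instance
    # `llm`: a bare string becomes a single user message, and for o1-family
    # models any system message is rewritten because those models reject the
    # system role.
    #
    #     llm._format_messages("hi")
    #     # -> [{"role": "user", "content": "hi"}]
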
    def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
        """Format tools for OpenAI function calling.

        Args:
            tools: List of tool definitions

        Returns:
            OpenAI-formatted tool definitions
        """
        if not tools or not self.model_config.get("supports_tools", True):
            return None

        formatted_tools = []
        for tool in tools:
            # Convert to OpenAI tool format
            formatted_tool = {
                "type": "function",
                "function": {
                    "name": tool.get("name", ""),
                    "description": tool.get("description", ""),
                    "parameters": tool.get("parameters", {}),
                },
            }
            formatted_tools.append(formatted_tool)

        return formatted_tools

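    # Illustrative sketch, assuming the method above: a plain tool dict is
    # wrapped into the OpenAI "function" tool schema. The tool name and
    # parameters shown here are made up for the example.
    #
    #     llm._format_tools([{"name": "search", "description": "Web search",
    #                         "parameters": {"type": "object", "properties": {}}}])
    #     # -> [{"type": "function", "function": {"name": "search", ...}}]
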
    def _handle_tool_calls(
        self,
        response,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> Any:
        """Handle tool calls from OpenAI response.

        Args:
            response: OpenAI API response
            available_functions: Dict mapping function names to callables
            from_task: Optional task context
            from_agent: Optional agent context

        Returns:
            Result of function execution or error message
        """
        message = response.choices[0].message

        if not message.tool_calls or not available_functions:
            return message.content

        # Execute the first tool call
        tool_call = message.tool_calls[0]
        function_name = tool_call.function.name
        function_args = {}

        if function_name not in available_functions:
            return f"Error: Function '{function_name}' not found in available functions"

        try:
            # Parse function arguments
            function_args = json.loads(tool_call.function.arguments)
            fn = available_functions[function_name]

            # Execute function with event tracking
            assert hasattr(crewai_event_bus, "emit")
            started_at = datetime.now()
            crewai_event_bus.emit(
                self,
                event=ToolUsageStartedEvent(
                    tool_name=function_name,
                    tool_args=function_args,
                ),
            )

            result = fn(**function_args)

            crewai_event_bus.emit(
                self,
                event=ToolUsageFinishedEvent(
                    output=result,
                    tool_name=function_name,
                    tool_args=function_args,
                    started_at=started_at,
                    finished_at=datetime.now(),
                ),
            )

            # Emit success event
            event_data = {
                "response": result,
                "call_type": LLMCallType.TOOL_CALL,
                "model": self.model,
            }
            if from_task is not None:
                event_data["from_task"] = from_task
            if from_agent is not None:
                event_data["from_agent"] = from_agent

            crewai_event_bus.emit(
                self,
                event=LLMCallCompletedEvent(**event_data),
            )

            return result

        except json.JSONDecodeError as e:
            error_msg = f"Error parsing function arguments: {e}"
            crewai_event_bus.emit(
                self,
                event=ToolUsageErrorEvent(
                    tool_name=function_name,
                    tool_args=function_args,
                    error=error_msg,
                ),
            )
            return error_msg
        except Exception as e:
            error_msg = f"Error executing function '{function_name}': {e}"
            crewai_event_bus.emit(
                self,
                event=ToolUsageErrorEvent(
                    tool_name=function_name,
                    tool_args=function_args,
                    error=error_msg,
                ),
            )
            return error_msg

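    # Illustrative sketch, assuming the method above: available_functions maps
    # tool names to plain callables, and only the first tool call in the
    # response is executed. The `add` function and `response` object below are
    # hypothetical (response stands for a prior chat.completions result).
    #
    #     def add(a: int, b: int) -> int:
    #         return a + b
    #
    #     llm._handle_tool_calls(response, available_functions={"add": add})
    #     # If the model requested add(a=2, b=3), this returns 5.
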
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
        from_task: Optional[Any] = None,
        from_agent: Optional[Any] = None,
    ) -> Union[str, Any]:
        """Call OpenAI API with the given messages.

        Args:
            messages: Input messages for the LLM
            tools: Optional list of tool schemas
            callbacks: Optional callbacks to execute
            available_functions: Optional dict of available functions
            from_task: Optional task context
            from_agent: Optional agent context

        Returns:
            LLM response or tool execution result

        Raises:
            ValueError: If messages format is invalid
            RuntimeError: If API call fails
        """
        # Emit call started event
        assert hasattr(crewai_event_bus, "emit")

        # Prepare event data
        started_event_data = {
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions,
            "model": self.model,
        }
        if from_task is not None:
            started_event_data["from_task"] = from_task
        if from_agent is not None:
            started_event_data["from_agent"] = from_agent

        crewai_event_bus.emit(
            self,
            event=LLMCallStartedEvent(**started_event_data),
        )

        try:
            # Format messages
            formatted_messages = self._format_messages(messages)

            # Prepare API call parameters
            api_params = {
                "model": self.model,
                "messages": formatted_messages,
            }

            # Add optional parameters
            if self.temperature is not None:
                api_params["temperature"] = self.temperature

            if self.top_p is not None:
                api_params["top_p"] = self.top_p

            if self.n is not None:
                api_params["n"] = self.n

            if self.max_tokens is not None:
                api_params["max_tokens"] = self.max_tokens

            if self.presence_penalty is not None:
                api_params["presence_penalty"] = self.presence_penalty

            if self.frequency_penalty is not None:
                api_params["frequency_penalty"] = self.frequency_penalty

            if self.logit_bias is not None:
                api_params["logit_bias"] = self.logit_bias

            if self.seed is not None:
                api_params["seed"] = self.seed

            if self.logprobs is not None:
                api_params["logprobs"] = self.logprobs

            if self.top_logprobs is not None:
                api_params["top_logprobs"] = self.top_logprobs

            if self.stop:
                api_params["stop"] = self.stop

            if self.response_format is not None:
                # Handle structured output for Pydantic models
                if hasattr(self.response_format, "model_json_schema"):
                    api_params["response_format"] = {
                        "type": "json_schema",
                        "json_schema": {
                            "name": self.response_format.__name__,
                            "schema": self.response_format.model_json_schema(),
                            "strict": True,
                        },
                    }
                else:
                    api_params["response_format"] = self.response_format

            if self.reasoning_effort is not None and "o1" in self.model:
                api_params["reasoning_effort"] = self.reasoning_effort

            # Add tools if provided and supported
            formatted_tools = self._format_tools(tools)
            if formatted_tools:
                api_params["tools"] = formatted_tools
                api_params["tool_choice"] = "auto"

            # Execute callbacks before API call
            if callbacks:
                for callback in callbacks:
                    if hasattr(callback, "on_llm_start"):
                        callback.on_llm_start(
                            serialized={"name": self.__class__.__name__},
                            prompts=[str(formatted_messages)],
                        )

            # Make API call
            if self.stream:
                response = self.client.chat.completions.create(
                    stream=True, **api_params
                )
                # Handle streaming (simplified for now)
                full_response = ""
                for chunk in response:
                    if (
                        hasattr(chunk.choices[0].delta, "content")
                        and chunk.choices[0].delta.content
                    ):
                        full_response += chunk.choices[0].delta.content
                result = full_response
            else:
                response = self.client.chat.completions.create(**api_params)
                # Handle tool calls if present
                result = self._handle_tool_calls(
                    response, available_functions, from_task, from_agent
                )

                # If no tool calls, return text content
                if result == response.choices[0].message.content:
                    result = response.choices[0].message.content or ""

            # Execute callbacks after API call
            if callbacks:
                for callback in callbacks:
                    if hasattr(callback, "on_llm_end"):
                        callback.on_llm_end(response=result)

            # Emit completion event
            completion_event_data = {
                "messages": formatted_messages,
                "response": result,
                "call_type": LLMCallType.LLM_CALL,
                "model": self.model,
            }
            if from_task is not None:
                completion_event_data["from_task"] = from_task
            if from_agent is not None:
                completion_event_data["from_agent"] = from_agent

            crewai_event_bus.emit(
                self,
                event=LLMCallCompletedEvent(**completion_event_data),
            )

            return result

        except Exception as e:
            # Execute error callbacks
            if callbacks:
                for callback in callbacks:
                    if hasattr(callback, "on_llm_error"):
                        callback.on_llm_error(error=e)

            # Emit failed event
            failed_event_data = {
                "error": str(e),
            }
            if from_task is not None:
                failed_event_data["from_task"] = from_task
            if from_agent is not None:
                failed_event_data["from_agent"] = from_agent

            crewai_event_bus.emit(
                self,
                event=LLMCallFailedEvent(**failed_event_data),
            )

            raise RuntimeError(f"OpenAI API call failed: {str(e)}") from e

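    # Usage sketch, assuming the class above and a valid OPENAI_API_KEY: a
    # plain text call returns the assistant message content as a string, and
    # the started/completed events are emitted on crewai_event_bus.
    #
    #     llm = OpenAILLM(model="gpt-4o-mini")
    #     answer = llm.call("Summarize CrewAI in one sentence.")
    #     print(answer)
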
    def supports_stop_words(self) -> bool:
        """Check if OpenAI models support stop words."""
        return True

    def get_context_window_size(self) -> int:
        """Get the context window size for the current model."""
        if self.context_window_size != 0:
            return self.context_window_size

        # Use 85% of the context window like the original LLM class
        context_window = self.model_config.get("context_window", 4096)
        self.context_window_size = int(context_window * 0.85)
        return self.context_window_size

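    # Worked example, assuming the config table above: for a 128,000-token
    # window, int(128000 * 0.85) == 108800, so get_context_window_size()
    # reports 108800 usable tokens and caches that value on the instance.
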
    def supports_function_calling(self) -> bool:
        """Check if the current model supports function calling."""
        return self.model_config.get("supports_tools", True)

@@ -7,39 +7,71 @@ from crewai.llm import LLM, BaseLLM


def create_llm(
    llm_value: Union[str, LLM, Any, None] = None,
    prefer_native: Optional[bool] = None,
) -> Optional[LLM | BaseLLM]:
    """
    Creates or returns an LLM instance based on the given llm_value.
    Now supports provider prefixes like 'openai/gpt-4' for native implementations.

    Args:
        llm_value (str | BaseLLM | Any | None):
            - str: The model name (e.g., "gpt-4").
            - str: The model name (e.g., "gpt-4" or "openai/gpt-4").
            - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
            - Any: Attempt to extract known attributes like model_name, temperature, etc.
            - None: Use environment-based or fallback default model.
        prefer_native (bool | None):
            - True: Use native provider implementations when available
            - False: Always use LiteLLM implementation
            - None: Use environment variable CREWAI_PREFER_NATIVE_LLMS (default: True)
            - Note: Provider prefixes (openai/, anthropic/) override this setting

    Returns:
        A BaseLLM instance if successful, or None if something fails.

    Examples:
        create_llm("gpt-4")  # Uses LiteLLM or native based on prefer_native
        create_llm("openai/gpt-4")  # Always uses native OpenAI implementation
        create_llm("anthropic/claude-3-sonnet")  # Future: native Anthropic
    """

    # 1) If llm_value is already a BaseLLM or LLM object, return it directly
    if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
        return llm_value

    # 2) If llm_value is a string (model name)
    # 2) Determine if we should prefer native implementations (unless provider prefix is used)
    if prefer_native is None:
        prefer_native = os.getenv("CREWAI_PREFER_NATIVE_LLMS", "true").lower() in (
            "true",
            "1",
            "yes",
        )

    # 3) If llm_value is a string (model name)
    if isinstance(llm_value, str):
        try:
            # Provider prefix (openai/, anthropic/) always takes precedence
            if "/" in llm_value:
                created_llm = LLM(model=llm_value)  # LLM class handles routing
                return created_llm

            # Try native implementation first if preferred and no prefix
            if prefer_native:
                native_llm = _create_native_llm(llm_value)
                if native_llm:
                    return native_llm

            # Fallback to LiteLLM
            created_llm = LLM(model=llm_value)
            return created_llm
        except Exception as e:
            print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
            return None

    # 3) If llm_value is None, parse environment variables or use default
    # 4) If llm_value is None, parse environment variables or use default
    if llm_value is None:
        return _llm_via_environment_or_fallback()
        return _llm_via_environment_or_fallback(prefer_native)

    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
    # 5) Otherwise, attempt to extract relevant attributes from an unknown object
    try:
        # Extract attributes with explicit types
        model = (
@@ -48,6 +80,8 @@ def create_llm(
            or getattr(llm_value, "deployment_name", None)
            or str(llm_value)
        )

        # Extract other parameters
        temperature: Optional[float] = getattr(llm_value, "temperature", None)
        max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
        logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
@@ -56,6 +90,7 @@ def create_llm(
        base_url: Optional[str] = getattr(llm_value, "base_url", None)
        api_base: Optional[str] = getattr(llm_value, "api_base", None)

        # Use LLM class constructor which handles routing
        created_llm = LLM(
            model=model,
            temperature=temperature,
@@ -72,9 +107,94 @@ def create_llm(
        return None

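# Usage sketch, assuming the function above: the CREWAI_PREFER_NATIVE_LLMS
# environment variable controls the default routing when no provider prefix
# is given, while an explicit prefix always routes through the LLM class.
#
#     os.environ["CREWAI_PREFER_NATIVE_LLMS"] = "false"
#     create_llm("gpt-4")          # falls back to the LiteLLM-backed LLM class
#     create_llm("openai/gpt-4")   # prefix still routes to the native OpenAILLM
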
def _llm_via_environment_or_fallback() -> Optional[LLM]:
def _create_native_llm(model: str, **kwargs) -> Optional[BaseLLM]:
    """
    Create a native LLM implementation based on the model name.

    Args:
        model: The model name (e.g., 'gpt-4', 'claude-3-sonnet')
        **kwargs: Additional parameters for the LLM

    Returns:
        Native LLM instance if supported, None otherwise
    """
    try:
        # OpenAI models
        if _is_openai_model(model):
            from crewai.llms.openai import OpenAILLM

            return OpenAILLM(model=model, **kwargs)

        # Claude models
        if _is_claude_model(model):
            from crewai.llms.anthropic import ClaudeLLM

            return ClaudeLLM(model=model, **kwargs)

        # Gemini models
        if _is_gemini_model(model):
            from crewai.llms.google import GeminiLLM

            return GeminiLLM(model=model, **kwargs)

        # No native implementation found
        return None

    except Exception as e:
        print(f"Failed to create native LLM for model '{model}': {e}")
        return None

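# Illustrative dispatch sketch, assuming the helper above and that the native
# provider modules and API keys are configured: the model prefix decides which
# native class is built.
#
#     _create_native_llm("gpt-4o")            # -> OpenAILLM instance
#     _create_native_llm("claude-3-sonnet")   # -> ClaudeLLM instance
#     _create_native_llm("llama-3-70b")       # -> None (no native implementation)
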
def _is_openai_model(model: str) -> bool:
    """Check if a model is from OpenAI."""
    openai_prefixes = (
        "gpt-",
        "text-davinci",
        "text-curie",
        "text-babbage",
        "text-ada",
        "davinci",
        "curie",
        "babbage",
        "ada",
        "o1-",
        "o3-",
        "o4-",
        "chatgpt-",
    )

    model_lower = model.lower()
    return any(model_lower.startswith(prefix) for prefix in openai_prefixes)


def _is_claude_model(model: str) -> bool:
    """Check if a model is from Anthropic (Claude)."""
    claude_prefixes = (
        "claude-",
        "claude",  # For cases like just "claude"
    )

    model_lower = model.lower()
    return any(model_lower.startswith(prefix) for prefix in claude_prefixes)


def _is_gemini_model(model: str) -> bool:
    """Check if a model is from Google (Gemini)."""
    gemini_prefixes = (
        "gemini-",
        "gemini",  # For cases like just "gemini"
    )

    model_lower = model.lower()
    return any(model_lower.startswith(prefix) for prefix in gemini_prefixes)

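# Quick check sketch, assuming the predicates above: prefix matching is
# case-insensitive and purely string-based.
#
#     _is_openai_model("GPT-4o")          # -> True
#     _is_claude_model("claude-3-haiku")  # -> True
#     _is_gemini_model("gemini-1.5-pro")  # -> True
#     _is_openai_model("mistral-large")   # -> False
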
def _llm_via_environment_or_fallback(
    prefer_native: bool = True,
) -> Optional[LLM | BaseLLM]:
    """
    Helper function: if llm_value is None, load the model from environment variables or fall back to the default model.
    Now with native provider support.
    """
    model_name = (
        os.environ.get("MODEL")
@@ -83,7 +203,13 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
        or DEFAULT_LLM_MODEL
    )

    # Initialize parameters with correct types
    # Try native implementation first if preferred
    if prefer_native:
        native_llm = _create_native_llm(model_name)
        if native_llm:
            return native_llm

    # Initialize parameters with correct types (original logic continues)
    model: str = model_name
    temperature: Optional[float] = None
    max_tokens: Optional[int] = None