Compare commits

...

1 Commit

Author: lorenzejay
SHA1: 79f716162b
Date: 2025-08-22 13:39:14 -07:00

feat: enhance LLM routing and support for native implementations
- Implemented provider-prefix routing for LLMs, enabling native implementations for OpenAI, Anthropic, and Google Gemini.
- Added a fallback to LiteLLM when a native implementation is unavailable.
- Updated the create_llm function to support provider prefixes and a native-preference setting.
- Introduced new LLM classes for Claude and Gemini, ensuring compatibility with the existing LLM interface.
- Enhanced documentation for the LLM utility functions to clarify usage and examples.
9 changed files with 2136 additions and 9 deletions
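A minimal usage sketch of the new routing, based on the code in this diff (it assumes create_llm lives in crewai.utilities.llm_utils, which the diff does not name explicitly):

from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm

# A provider prefix routes to the native implementation; if the native module
# cannot be imported, the constructor falls back to LiteLLM with the prefix stripped.
claude = LLM(model="anthropic/claude-3-5-sonnet-20241022")

# Without a prefix, create_llm consults prefer_native (or CREWAI_PREFER_NATIVE_LLMS)
# before falling back to LiteLLM.
gpt4 = create_llm("gpt-4", prefer_native=True)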

View File

@@ -316,6 +316,143 @@ class LLM(BaseLLM):
stream: bool = False,
**kwargs,
):
# Check for provider prefixes and route to native implementations
if "/" in model:
provider, actual_model = model.split("/", 1)
# Route to OpenAI native implementation
if provider.lower() == "openai":
try:
from crewai.llms.openai import OpenAILLM
# Create native OpenAI instance with all the same parameters
native_llm = OpenAILLM(
model=actual_model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
stream=stream,
**kwargs,
)
# Replace this LLM instance with the native one
self.__class__ = native_llm.__class__
self.__dict__.update(native_llm.__dict__)
return
except ImportError:
# Fall back to LiteLLM if native implementation unavailable
print(
f"Native OpenAI implementation not available, using LiteLLM for {model}"
)
model = actual_model # Remove the prefix for LiteLLM
# Route to Claude native implementation
elif provider.lower() == "anthropic":
try:
from crewai.llms.anthropic import ClaudeLLM
# Create native Claude instance with all the same parameters
native_llm = ClaudeLLM(
model=actual_model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
stream=stream,
**kwargs,
)
# Replace this LLM instance with the native one
self.__class__ = native_llm.__class__
self.__dict__.update(native_llm.__dict__)
return
except ImportError:
# Fall back to LiteLLM if native implementation unavailable
print(
f"Native Claude implementation not available, using LiteLLM for {model}"
)
model = actual_model # Remove the prefix for LiteLLM
# Route to Gemini native implementation
elif provider.lower() == "google":
try:
from crewai.llms.google import GeminiLLM
# Create native Gemini instance with all the same parameters
native_llm = GeminiLLM(
model=actual_model,
timeout=timeout,
temperature=temperature,
top_p=top_p,
n=n,
stop=stop,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
presence_penalty=presence_penalty,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
response_format=response_format,
seed=seed,
logprobs=logprobs,
top_logprobs=top_logprobs,
base_url=base_url,
api_base=api_base,
api_version=api_version,
api_key=api_key,
callbacks=callbacks,
reasoning_effort=reasoning_effort,
stream=stream,
**kwargs,
)
# Replace this LLM instance with the native one
self.__class__ = native_llm.__class__
self.__dict__.update(native_llm.__dict__)
return
except ImportError:
# Fall back to LiteLLM if native implementation unavailable
print(
f"Native Gemini implementation not available, using LiteLLM for {model}"
)
model = actual_model # Remove the prefix for LiteLLM
# Continue with original LiteLLM initialization
self.model = model
self.timeout = timeout
self.temperature = temperature
@@ -1139,7 +1276,11 @@ class LLM(BaseLLM):
# TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917
# Ollama doesn't support the last message being 'assistant'
if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant":
if (
"ollama" in self.model.lower()
and messages
and messages[-1]["role"] == "assistant"
):
return messages + [{"role": "user", "content": ""}]
# Handle Anthropic models

View File

@@ -1 +1,11 @@
"""LLM implementations for crewAI."""
"""CrewAI LLM implementations."""
from .base_llm import BaseLLM
from .openai import OpenAILLM
from .anthropic import ClaudeLLM
from .google import GeminiLLM
# Re-export the LLM implementations for backward-compatible imports
__all__ = ["BaseLLM", "OpenAILLM", "ClaudeLLM", "GeminiLLM"]

View File

@@ -0,0 +1,5 @@
"""Anthropic Claude LLM implementation for CrewAI."""
from .claude import ClaudeLLM
__all__ = ["ClaudeLLM"]

View File

@@ -0,0 +1,569 @@
import os
from typing import Any, Dict, List, Optional, Union, Type, Literal
from anthropic import Anthropic
from pydantic import BaseModel
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
)
from crewai.utilities.events.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
)
from datetime import datetime
class ClaudeLLM(BaseLLM):
"""Anthropic Claude LLM implementation with full LLM class compatibility."""
def __init__(
self,
model: str = "claude-3-5-sonnet-20241022",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None, # Not supported by Claude but kept for compatibility
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[
float
] = None, # Not supported but kept for compatibility
frequency_penalty: Optional[
float
] = None, # Not supported but kept for compatibility
logit_bias: Optional[
Dict[int, float]
] = None, # Not supported but kept for compatibility
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None, # Not supported but kept for compatibility
logprobs: Optional[int] = None, # Not supported but kept for compatibility
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None, # Not used by Anthropic
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[
Literal["none", "low", "medium", "high"]
] = None, # Not used by Claude
stream: bool = False,
max_retries: int = 2,
# Claude-specific parameters
thinking_mode: bool = False, # Enable Claude's thinking mode
top_k: Optional[int] = None, # Claude-specific sampling parameter
**kwargs,
):
"""Initialize Claude LLM with full compatibility.
Args:
model: Claude model name (e.g., 'claude-3-5-sonnet-20241022')
timeout: Request timeout in seconds
temperature: Sampling temperature (0-1 for Claude)
top_p: Nucleus sampling parameter
n: Number of completions (not supported by Claude, kept for compatibility)
stop: Stop sequences
max_completion_tokens: Maximum tokens in completion
max_tokens: Maximum tokens (legacy parameter)
presence_penalty: Not supported by Claude, kept for compatibility
frequency_penalty: Not supported by Claude, kept for compatibility
logit_bias: Not supported by Claude, kept for compatibility
response_format: Pydantic model for structured output
seed: Not supported by Claude, kept for compatibility
logprobs: Not supported by Claude, kept for compatibility
top_logprobs: Not supported by Claude, kept for compatibility
base_url: Custom API base URL
api_base: Legacy API base parameter
api_version: Not used by Anthropic
api_key: Anthropic API key
callbacks: List of callback functions
reasoning_effort: Not used by Claude, kept for compatibility
stream: Whether to stream responses
max_retries: Number of retries for failed requests
thinking_mode: Enable Claude's thinking mode (if supported)
top_k: Claude-specific top-k sampling parameter
**kwargs: Additional parameters
"""
super().__init__(model=model, temperature=temperature)
# Store all parameters for compatibility
self.timeout = timeout
self.top_p = top_p
self.n = n # Claude doesn't support n>1, but we store it for compatibility
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens or max_completion_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.api_base = api_base or base_url
self.base_url = base_url or api_base
self.api_version = api_version
self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY")
self.callbacks = callbacks
self.reasoning_effort = reasoning_effort
self.stream = stream
self.additional_params = kwargs
self.context_window_size = 0
# Claude-specific parameters
self.thinking_mode = thinking_mode
self.top_k = top_k
# Normalize stop parameter to match LLM class behavior
if stop is None:
self.stop: List[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
# Initialize Anthropic client
client_kwargs = {}
if self.api_key:
client_kwargs["api_key"] = self.api_key
if self.base_url:
client_kwargs["base_url"] = self.base_url
if self.timeout:
client_kwargs["timeout"] = self.timeout
if max_retries:
client_kwargs["max_retries"] = max_retries
# Add any additional kwargs that might be relevant to the client
for key, value in kwargs.items():
if key not in ["thinking_mode", "top_k"]: # Exclude our custom params
client_kwargs[key] = value
self.client = Anthropic(**client_kwargs)
self.model_config = self._get_model_config()
def _get_model_config(self) -> Dict[str, Any]:
"""Get model-specific configuration for Claude models."""
# Claude model configurations based on Anthropic's documentation
model_configs = {
# Claude 3.5 Sonnet
"claude-3-5-sonnet-20241022": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
"claude-3-5-sonnet-20240620": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
# Claude 3.5 Haiku
"claude-3-5-haiku-20241022": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
# Claude 3 Opus
"claude-3-opus-20240229": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
# Claude 3 Sonnet
"claude-3-sonnet-20240229": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
# Claude 3 Haiku
"claude-3-haiku-20240307": {
"context_window": 200000,
"supports_tools": True,
"supports_vision": True,
},
# Claude 2.1
"claude-2.1": {
"context_window": 200000,
"supports_tools": False,
"supports_vision": False,
},
"claude-2": {
"context_window": 100000,
"supports_tools": False,
"supports_vision": False,
},
# Claude Instant
"claude-instant-1.2": {
"context_window": 100000,
"supports_tools": False,
"supports_vision": False,
},
}
# Default config if model not found
default_config = {
"context_window": 200000,
"supports_tools": True,
"supports_vision": False,
}
# Try exact match first
if self.model in model_configs:
return model_configs[self.model]
# Try prefix match for versioned models
for model_prefix, config in model_configs.items():
if self.model.startswith(model_prefix):
return config
return default_config
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
"""Format messages for Anthropic API.
Args:
messages: Input messages as string or list of dicts
Returns:
List of properly formatted message dicts
"""
if isinstance(messages, str):
return [{"role": "user", "content": messages}]
# Validate message format
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
raise ValueError(
"Each message must be a dict with 'role' and 'content' keys"
)
# Claude requires alternating user/assistant messages and cannot start with system
formatted_messages = []
system_message = None
for msg in messages:
if msg["role"] == "system":
# Store system message separately - Claude handles it differently
if system_message is None:
system_message = msg["content"]
else:
system_message += "\n\n" + msg["content"]
else:
formatted_messages.append(msg)
# Ensure messages alternate and start with user
if formatted_messages and formatted_messages[0]["role"] != "user":
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
# Store system message for later use
self._system_message = system_message
return formatted_messages
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
"""Format tools for Claude function calling.
Args:
tools: List of tool definitions
Returns:
Claude-formatted tool definitions
"""
if not tools or not self.model_config.get("supports_tools", True):
return None
formatted_tools = []
for tool in tools:
# Convert to Claude tool format
formatted_tool = {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"input_schema": tool.get("parameters", {}),
}
formatted_tools.append(formatted_tool)
return formatted_tools
def _handle_tool_calls(
self,
response,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Any:
"""Handle tool calls from Claude response.
Args:
response: Claude API response
available_functions: Dict mapping function names to callables
from_task: Optional task context
from_agent: Optional agent context
Returns:
Result of function execution or error message
"""
# Claude returns tool use in content blocks
if not hasattr(response, "content") or not available_functions:
return response.content[0].text if response.content else ""
# Look for tool use blocks
for content_block in response.content:
if hasattr(content_block, "type") and content_block.type == "tool_use":
function_name = content_block.name
function_args = {}
if function_name not in available_functions:
return f"Error: Function '{function_name}' not found in available functions"
try:
# Claude provides arguments as a dict
function_args = content_block.input
fn = available_functions[function_name]
# Execute function with event tracking
assert hasattr(crewai_event_bus, "emit")
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=function_name,
tool_args=function_args,
),
)
result = fn(**function_args)
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=function_name,
tool_args=function_args,
started_at=started_at,
finished_at=datetime.now(),
),
)
# Emit success event
event_data = {
"response": result,
"call_type": LLMCallType.TOOL_CALL,
"model": self.model,
}
if from_task is not None:
event_data["from_task"] = from_task
if from_agent is not None:
event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**event_data),
)
return result
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e}"
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=error_msg,
),
)
return error_msg
# If no tool calls, return text content
return response.content[0].text if response.content else ""
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
"""Call Claude API with the given messages.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional callbacks to execute
available_functions: Optional dict of available functions
from_task: Optional task context
from_agent: Optional agent context
Returns:
LLM response or tool execution result
Raises:
ValueError: If messages format is invalid
RuntimeError: If API call fails
"""
# Emit call started event
print("calling from native claude", messages)
assert hasattr(crewai_event_bus, "emit")
# Prepare event data
started_event_data = {
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"model": self.model,
}
if from_task is not None:
started_event_data["from_task"] = from_task
if from_agent is not None:
started_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(**started_event_data),
)
try:
# Format messages
formatted_messages = self._format_messages(messages)
system_message = getattr(self, "_system_message", None)
# Prepare API call parameters
api_params = {
"model": self.model,
"messages": formatted_messages,
"max_tokens": self.max_tokens or 4000, # Claude requires max_tokens
}
# Add system message if present
if system_message:
api_params["system"] = system_message
# Add optional parameters that Claude supports
if self.temperature is not None:
api_params["temperature"] = self.temperature
if self.top_p is not None:
api_params["top_p"] = self.top_p
if self.top_k is not None:
api_params["top_k"] = self.top_k
if self.stop:
api_params["stop_sequences"] = self.stop
# Add tools if provided and supported
formatted_tools = self._format_tools(tools)
if formatted_tools:
api_params["tools"] = formatted_tools
# Execute callbacks before API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_start"):
callback.on_llm_start(
serialized={"name": self.__class__.__name__},
prompts=[str(formatted_messages)],
)
# Make API call
if self.stream:
response = self.client.messages.create(stream=True, **api_params)
# Handle streaming (simplified implementation)
full_response = ""
try:
for event in response:
if hasattr(event, "type"):
if event.type == "content_block_delta":
if hasattr(event, "delta") and hasattr(
event.delta, "text"
):
full_response += event.delta.text
except Exception as e:
# If streaming fails, fall back to the response we have
print(f"Streaming error (continuing with partial response): {e}")
result = full_response or "No response content"
else:
response = self.client.messages.create(**api_params)
# Handle tool calls if present
result = self._handle_tool_calls(
response, available_functions, from_task, from_agent
)
# Execute callbacks after API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_end"):
callback.on_llm_end(response=result)
# Emit completion event
completion_event_data = {
"messages": formatted_messages,
"response": result,
"call_type": LLMCallType.LLM_CALL,
"model": self.model,
}
if from_task is not None:
completion_event_data["from_task"] = from_task
if from_agent is not None:
completion_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**completion_event_data),
)
return result
except Exception as e:
# Execute error callbacks
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_error"):
callback.on_llm_error(error=e)
# Emit failed event
failed_event_data = {
"error": str(e),
}
if from_task is not None:
failed_event_data["from_task"] = from_task
if from_agent is not None:
failed_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(**failed_event_data),
)
raise RuntimeError(f"Claude API call failed: {str(e)}") from e
def supports_stop_words(self) -> bool:
"""Check if Claude models support stop words."""
return True
def get_context_window_size(self) -> int:
"""Get the context window size for the current model."""
if self.context_window_size != 0:
return self.context_window_size
# Use 85% of the context window like the original LLM class
context_window = self.model_config.get("context_window", 200000)
self.context_window_size = int(context_window * 0.85)
return self.context_window_size
def supports_function_calling(self) -> bool:
"""Check if the current model supports function calling."""
return self.model_config.get("supports_tools", True)
def supports_vision(self) -> bool:
"""Check if the current model supports vision capabilities."""
return self.model_config.get("supports_vision", False)

View File

@@ -0,0 +1,5 @@
"""Google Gemini LLM implementation for CrewAI."""
from .gemini import GeminiLLM
__all__ = ["GeminiLLM"]

View File

@@ -0,0 +1,737 @@
import os
from typing import Any, Dict, List, Optional, Union, Type, Literal, TYPE_CHECKING
from pydantic import BaseModel
if TYPE_CHECKING:
from google import genai
from google.genai import types
try:
from google import genai
from google.genai import types
except ImportError:
genai = None
types = None
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
)
from crewai.utilities.events.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
)
from datetime import datetime
class GeminiLLM(BaseLLM):
"""Google Gemini LLM implementation using the official Google Gen AI Python SDK."""
def __init__(
self,
model: str = "gemini-1.5-pro",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None, # Not supported by Gemini but kept for compatibility
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[
float
] = None, # Not supported but kept for compatibility
frequency_penalty: Optional[
float
] = None, # Not supported but kept for compatibility
logit_bias: Optional[
Dict[int, float]
] = None, # Not supported but kept for compatibility
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None, # Not supported but kept for compatibility
logprobs: Optional[int] = None, # Not supported but kept for compatibility
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility
base_url: Optional[str] = None, # Not used by Gemini
api_base: Optional[str] = None, # Not used by Gemini
api_version: Optional[str] = None, # Not used by Gemini
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[
Literal["none", "low", "medium", "high"]
] = None, # Not used by Gemini
stream: bool = False,
max_retries: int = 2,
# Gemini-specific parameters
top_k: Optional[int] = None, # Gemini top-k sampling parameter
candidate_count: int = 1, # Number of response candidates
safety_settings: Optional[
List[Dict[str, Any]]
] = None, # Gemini safety settings
generation_config: Optional[
Dict[str, Any]
] = None, # Additional generation config
# Vertex AI parameters
use_vertex_ai: bool = False,
project_id: Optional[str] = None,
location: str = "us-central1",
**kwargs,
):
"""Initialize Gemini LLM with the official Google Gen AI SDK.
Args:
model: Gemini model name (e.g., 'gemini-1.5-pro', 'gemini-2.0-flash-001')
timeout: Request timeout in seconds
temperature: Sampling temperature (0-2 for Gemini)
top_p: Nucleus sampling parameter
n: Number of completions (not supported by Gemini, kept for compatibility)
stop: Stop sequences
max_completion_tokens: Maximum tokens in completion
max_tokens: Maximum tokens (legacy parameter)
presence_penalty: Not supported by Gemini, kept for compatibility
frequency_penalty: Not supported by Gemini, kept for compatibility
logit_bias: Not supported by Gemini, kept for compatibility
response_format: Pydantic model for structured output
seed: Not supported by Gemini, kept for compatibility
logprobs: Not supported by Gemini, kept for compatibility
top_logprobs: Not supported by Gemini, kept for compatibility
base_url: Not used by Gemini
api_base: Not used by Gemini
api_version: Not used by Gemini
api_key: Google AI API key
callbacks: List of callback functions
reasoning_effort: Not used by Gemini, kept for compatibility
stream: Whether to stream responses
max_retries: Number of retries for failed requests
top_k: Gemini-specific top-k sampling parameter
candidate_count: Number of response candidates to generate
safety_settings: Gemini safety settings configuration
generation_config: Additional Gemini generation configuration
use_vertex_ai: Whether to use Vertex AI instead of Gemini API
project_id: Google Cloud project ID (required for Vertex AI)
location: Google Cloud region (default: us-central1)
**kwargs: Additional parameters
"""
# Check if Google Gen AI SDK is available
if genai is None or types is None:
raise ImportError(
"Google Gen AI Python SDK is required. Please install it with: "
"pip install google-genai"
)
super().__init__(model=model, temperature=temperature)
# Store all parameters for compatibility
self.timeout = timeout
self.top_p = top_p
self.n = n
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens or max_completion_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.api_base = api_base
self.base_url = base_url
self.api_version = api_version
self.callbacks = callbacks
self.reasoning_effort = reasoning_effort
self.stream = stream
self.additional_params = kwargs
self.context_window_size = 0
self.max_retries = max_retries
# Gemini-specific parameters
self.top_k = top_k
self.candidate_count = candidate_count
self.safety_settings = safety_settings or []
self.generation_config = generation_config or {}
# Vertex AI parameters
self.use_vertex_ai = use_vertex_ai
self.project_id = project_id or os.getenv("GOOGLE_CLOUD_PROJECT")
self.location = location
# API key handling
self.api_key = (
api_key
or os.getenv("GOOGLE_AI_API_KEY")
or os.getenv("GEMINI_API_KEY")
or os.getenv("GOOGLE_API_KEY")
)
# Normalize stop parameter to match LLM class behavior
if stop is None:
self.stop: List[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
# Initialize client attribute
self.client: Any = None
# Initialize the Google Gen AI client
self._initialize_client()
self.model_config = self._get_model_config()
def _initialize_client(self):
"""Initialize the Google Gen AI client."""
if genai is None or types is None:
return
try:
if self.use_vertex_ai:
if not self.project_id:
raise ValueError(
"project_id is required when use_vertex_ai=True. "
"Set it directly or via GOOGLE_CLOUD_PROJECT environment variable."
)
self.client = genai.Client(
vertexai=True,
project=self.project_id,
location=self.location,
)
else:
if not self.api_key:
raise ValueError(
"API key is required for Gemini Developer API. "
"Set it via api_key parameter or GOOGLE_AI_API_KEY/GEMINI_API_KEY environment variable."
)
self.client = genai.Client(api_key=self.api_key)
except Exception as e:
raise RuntimeError(
f"Failed to initialize Google Gen AI client: {str(e)}"
) from e
def _get_model_config(self) -> Dict[str, Any]:
"""Get model-specific configuration for Gemini models."""
# Gemini model configurations based on Google's documentation
model_configs = {
# Gemini 2.0 Flash (latest)
"gemini-2.0-flash": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-2.0-flash-001": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-2.0-flash-exp": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
# Gemini 1.5 Pro
"gemini-1.5-pro": {
"context_window": 2097152,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-pro-002": {
"context_window": 2097152,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-pro-001": {
"context_window": 2097152,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-pro-exp-0827": {
"context_window": 2097152,
"supports_tools": True,
"supports_vision": True,
},
# Gemini 1.5 Flash
"gemini-1.5-flash": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-flash-002": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-flash-001": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-flash-8b": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
"gemini-1.5-flash-8b-exp-0827": {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
},
# Legacy Gemini Pro
"gemini-pro": {
"context_window": 30720,
"supports_tools": True,
"supports_vision": False,
},
"gemini-pro-vision": {
"context_window": 16384,
"supports_tools": False,
"supports_vision": True,
},
# Gemini Ultra (when available)
"gemini-ultra": {
"context_window": 30720,
"supports_tools": True,
"supports_vision": True,
},
}
# Default config if model not found
default_config = {
"context_window": 1048576,
"supports_tools": True,
"supports_vision": True,
}
# Try exact match first
if self.model in model_configs:
return model_configs[self.model]
# Try prefix match for versioned models
for model_prefix, config in model_configs.items():
if self.model.startswith(model_prefix):
return config
return default_config
def _format_messages(self, messages: Union[str, List[Dict[str, str]]]) -> List[Any]:
"""Format messages for Google Gen AI SDK.
Args:
messages: Input messages as string or list of dicts
Returns:
List of properly formatted Content objects
"""
if genai is None or types is None:
return []
if isinstance(messages, str):
return [
types.Content(role="user", parts=[types.Part.from_text(text=messages)])
]
# Validate message format
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
raise ValueError(
"Each message must be a dict with 'role' and 'content' keys"
)
# Convert to Google Gen AI SDK format
formatted_messages = []
system_instruction = None
for msg in messages:
role = msg["role"]
content = msg["content"]
if role == "system":
# System instruction will be handled separately
system_instruction = content
elif role == "user":
formatted_messages.append(
types.Content(
role="user", parts=[types.Part.from_text(text=content)]
)
)
elif role == "assistant":
formatted_messages.append(
types.Content(
role="model", parts=[types.Part.from_text(text=content)]
)
)
# Store system instruction for later use
self._system_instruction = system_instruction
return formatted_messages
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[Any]]:
"""Format tools for Google Gen AI SDK function calling.
Args:
tools: List of tool definitions
Returns:
Google Gen AI SDK formatted tool definitions
"""
if genai is None or types is None:
return None
if not tools or not self.model_config.get("supports_tools", True):
return None
formatted_tools = []
for tool in tools:
# Convert to Google Gen AI SDK function declaration format
function_declaration = types.FunctionDeclaration(
name=tool.get("name", ""),
description=tool.get("description", ""),
parameters=tool.get("parameters", {}),
)
formatted_tools.append(
types.Tool(function_declarations=[function_declaration])
)
return formatted_tools
def _build_generation_config(
self,
system_instruction: Optional[str] = None,
tools: Optional[List[Any]] = None,
) -> Any:
"""Build Google Gen AI SDK generation config from parameters."""
if genai is None or types is None:
return {}
config_dict = self.generation_config.copy()
# Add parameters that map to Gemini's generation config
if self.temperature is not None:
config_dict["temperature"] = self.temperature
if self.top_p is not None:
config_dict["top_p"] = self.top_p
if self.top_k is not None:
config_dict["top_k"] = self.top_k
if self.max_tokens is not None:
config_dict["max_output_tokens"] = self.max_tokens
if self.candidate_count is not None:
config_dict["candidate_count"] = self.candidate_count
if self.stop:
config_dict["stop_sequences"] = self.stop
if self.stream:
config_dict["stream"] = True
# Add safety settings
if self.safety_settings:
config_dict["safety_settings"] = self.safety_settings
# Add response format if specified
if self.response_format:
config_dict["response_modalities"] = ["TEXT"]
# Add system instruction if present
if system_instruction:
config_dict["system_instruction"] = system_instruction
# Add tools if present
if tools:
config_dict["tools"] = tools
return types.GenerateContentConfig(**config_dict)
def _handle_tool_calls(
self,
response,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Any:
"""Handle tool calls from Google Gen AI SDK response.
Args:
response: Google Gen AI SDK response
available_functions: Dict mapping function names to callables
from_task: Optional task context
from_agent: Optional agent context
Returns:
Result of function execution or error message
"""
# Check if response has function calls
if (
not available_functions
or not hasattr(response, "candidates")
or not response.candidates
):
return response.text if hasattr(response, "text") else str(response)
candidate = response.candidates[0] if response.candidates else None
if (
not candidate
or not hasattr(candidate, "content")
or not hasattr(candidate.content, "parts")
):
return response.text if hasattr(response, "text") else str(response)
# Look for function call parts
for part in candidate.content.parts:
if hasattr(part, "function_call"):
function_call = part.function_call
function_name = function_call.name
function_args = {}
if function_name not in available_functions:
return f"Error: Function '{function_name}' not found in available functions"
try:
# Google Gen AI SDK provides arguments as a struct
function_args = (
dict(function_call.args)
if hasattr(function_call, "args")
else {}
)
fn = available_functions[function_name]
# Execute function with event tracking
assert hasattr(crewai_event_bus, "emit")
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=function_name,
tool_args=function_args,
),
)
result = fn(**function_args)
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=function_name,
tool_args=function_args,
started_at=started_at,
finished_at=datetime.now(),
),
)
# Emit success event
event_data = {
"response": result,
"call_type": LLMCallType.TOOL_CALL,
"model": self.model,
}
if from_task is not None:
event_data["from_task"] = from_task
if from_agent is not None:
event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**event_data),
)
return result
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e}"
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=error_msg,
),
)
return error_msg
# If no function calls, return text content
return response.text if hasattr(response, "text") else str(response)
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
"""Call Google Gen AI SDK with the given messages.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional callbacks to execute
available_functions: Optional dict of available functions
from_task: Optional task context
from_agent: Optional agent context
Returns:
LLM response or tool execution result
Raises:
ValueError: If messages format is invalid
RuntimeError: If API call fails
"""
# Emit call started event
print("calling from native gemini", messages)
assert hasattr(crewai_event_bus, "emit")
# Prepare event data
started_event_data = {
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"model": self.model,
}
if from_task is not None:
started_event_data["from_task"] = from_task
if from_agent is not None:
started_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(**started_event_data),
)
retry_count = 0
last_error = None
while retry_count <= self.max_retries:
try:
# Format messages
formatted_messages = self._format_messages(messages)
system_instruction = getattr(self, "_system_instruction", None)
# Format tools if provided and supported
formatted_tools = self._format_tools(tools)
# Build generation config
generation_config = self._build_generation_config(
system_instruction, formatted_tools
)
# Execute callbacks before API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_start"):
callback.on_llm_start(
serialized={"name": self.__class__.__name__},
prompts=[str(formatted_messages)],
)
# Prepare the API call parameters
api_params = {
"model": self.model,
"contents": formatted_messages,
"config": generation_config,
}
# Make API call
if self.stream:
# Streaming response
response_stream = self.client.models.generate_content(**api_params)
full_response = ""
try:
for chunk in response_stream:
if hasattr(chunk, "text") and chunk.text:
full_response += chunk.text
except Exception as e:
print(
f"Streaming error (continuing with partial response): {e}"
)
result = full_response or "No response content"
else:
# Non-streaming response
response = self.client.models.generate_content(**api_params)
# Handle tool calls if present
result = self._handle_tool_calls(
response, available_functions, from_task, from_agent
)
# Execute callbacks after API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_end"):
callback.on_llm_end(response=result)
# Emit completion event
completion_event_data = {
"messages": messages, # Use original messages, not formatted_messages
"response": result,
"call_type": LLMCallType.LLM_CALL,
"model": self.model,
}
if from_task is not None:
completion_event_data["from_task"] = from_task
if from_agent is not None:
completion_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**completion_event_data),
)
return result
except Exception as e:
last_error = e
retry_count += 1
if retry_count <= self.max_retries:
print(
f"Gemini API call failed (attempt {retry_count}/{self.max_retries + 1}): {e}"
)
continue
# All retries exhausted
# Execute error callbacks
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_error"):
callback.on_llm_error(error=e)
# Emit failed event
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(error=str(e)),
)
raise RuntimeError(
f"Gemini API call failed after {self.max_retries + 1} attempts: {str(e)}"
) from e
def supports_stop_words(self) -> bool:
"""Check if Gemini models support stop words."""
return True
def get_context_window_size(self) -> int:
"""Get the context window size for the current model."""
if self.context_window_size != 0:
return self.context_window_size
# Use 85% of the context window like the original LLM class
context_window = self.model_config.get("context_window", 1048576)
self.context_window_size = int(context_window * 0.85)
return self.context_window_size
def supports_function_calling(self) -> bool:
"""Check if the current model supports function calling."""
return self.model_config.get("supports_tools", True)
def supports_vision(self) -> bool:
"""Check if the current model supports vision capabilities."""
return self.model_config.get("supports_vision", False)

View File

@@ -0,0 +1,5 @@
"""OpenAI LLM implementation for CrewAI."""
from .chat import OpenAILLM
__all__ = ["OpenAILLM"]

View File

@@ -0,0 +1,529 @@
import json
import os
from typing import Any, Dict, List, Optional, Union, Type, Literal
from openai import OpenAI
from pydantic import BaseModel
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.llm_events import (
LLMCallCompletedEvent,
LLMCallFailedEvent,
LLMCallStartedEvent,
LLMCallType,
)
from crewai.utilities.events.tool_usage_events import (
ToolUsageStartedEvent,
ToolUsageFinishedEvent,
ToolUsageErrorEvent,
)
from datetime import datetime
class OpenAILLM(BaseLLM):
"""OpenAI LLM implementation with full LLM class compatibility."""
def __init__(
self,
model: str = "gpt-4",
timeout: Optional[Union[float, int]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
n: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
presence_penalty: Optional[float] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Type[BaseModel]] = None,
seed: Optional[int] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_base: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None,
stream: bool = False,
max_retries: int = 2,
**kwargs,
):
"""Initialize OpenAI LLM with full compatibility.
Args:
model: OpenAI model name (e.g., 'gpt-4', 'gpt-3.5-turbo')
timeout: Request timeout in seconds
temperature: Sampling temperature (0-2)
top_p: Nucleus sampling parameter
n: Number of completions to generate
stop: Stop sequences
max_completion_tokens: Maximum tokens in completion
max_tokens: Maximum tokens (legacy parameter)
presence_penalty: Presence penalty (-2 to 2)
frequency_penalty: Frequency penalty (-2 to 2)
logit_bias: Logit bias dictionary
response_format: Pydantic model for structured output
seed: Random seed for deterministic output
logprobs: Whether to return log probabilities
top_logprobs: Number of most likely tokens to return
base_url: Custom API base URL
api_base: Legacy API base parameter
api_version: API version (for Azure)
api_key: OpenAI API key
callbacks: List of callback functions
reasoning_effort: Reasoning effort for o1 models
stream: Whether to stream responses
max_retries: Number of retries for failed requests
**kwargs: Additional parameters
"""
super().__init__(model=model, temperature=temperature)
# Store all parameters for compatibility
self.timeout = timeout
self.top_p = top_p
self.n = n
self.max_completion_tokens = max_completion_tokens
self.max_tokens = max_tokens or max_completion_tokens
self.presence_penalty = presence_penalty
self.frequency_penalty = frequency_penalty
self.logit_bias = logit_bias
self.response_format = response_format
self.seed = seed
self.logprobs = logprobs
self.top_logprobs = top_logprobs
self.api_base = api_base or base_url
self.base_url = base_url or api_base
self.api_version = api_version
self.api_key = api_key or os.getenv("OPENAI_API_KEY")
self.callbacks = callbacks
self.reasoning_effort = reasoning_effort
self.stream = stream
self.additional_params = kwargs
self.context_window_size = 0
# Normalize stop parameter to match LLM class behavior
if stop is None:
self.stop: List[str] = []
elif isinstance(stop, str):
self.stop = [stop]
else:
self.stop = stop
# Initialize OpenAI client
self.client = OpenAI(
api_key=self.api_key,
base_url=self.base_url,
timeout=self.timeout,
max_retries=max_retries,
**kwargs,
)
self.model_config = self._get_model_config()
def _get_model_config(self) -> Dict[str, Any]:
"""Get model-specific configuration."""
# Enhanced model configurations matching current LLM_CONTEXT_WINDOW_SIZES
model_configs = {
"gpt-4": {"context_window": 8192, "supports_tools": True},
"gpt-4o": {"context_window": 128000, "supports_tools": True},
"gpt-4o-mini": {"context_window": 200000, "supports_tools": True},
"gpt-4-turbo": {"context_window": 128000, "supports_tools": True},
"gpt-4.1": {"context_window": 1047576, "supports_tools": True},
"gpt-4.1-mini": {"context_window": 1047576, "supports_tools": True},
"gpt-4.1-nano": {"context_window": 1047576, "supports_tools": True},
"gpt-3.5-turbo": {"context_window": 16385, "supports_tools": True},
"o1-preview": {"context_window": 128000, "supports_tools": False},
"o1-mini": {"context_window": 128000, "supports_tools": False},
"o3-mini": {"context_window": 200000, "supports_tools": False},
"o4-mini": {"context_window": 200000, "supports_tools": False},
}
# Default config if model not found
default_config = {"context_window": 4096, "supports_tools": True}
for model_prefix, config in model_configs.items():
if self.model.startswith(model_prefix):
return config
return default_config
def _format_messages(
self, messages: Union[str, List[Dict[str, str]]]
) -> List[Dict[str, str]]:
"""Format messages for OpenAI API.
Args:
messages: Input messages as string or list of dicts
Returns:
List of properly formatted message dicts
"""
if isinstance(messages, str):
return [{"role": "user", "content": messages}]
# Validate message format
for msg in messages:
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
raise ValueError(
"Each message must be a dict with 'role' and 'content' keys"
)
# Handle O1 model special case (system messages not supported)
if "o1" in self.model.lower():
formatted_messages = []
for msg in messages:
if msg["role"] == "system":
# Convert system messages to assistant messages for O1
formatted_messages.append(
{"role": "assistant", "content": msg["content"]}
)
else:
formatted_messages.append(msg)
return formatted_messages
return messages
def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]:
"""Format tools for OpenAI function calling.
Args:
tools: List of tool definitions
Returns:
OpenAI-formatted tool definitions
"""
if not tools or not self.model_config.get("supports_tools", True):
return None
formatted_tools = []
for tool in tools:
# Convert to OpenAI tool format
formatted_tool = {
"type": "function",
"function": {
"name": tool.get("name", ""),
"description": tool.get("description", ""),
"parameters": tool.get("parameters", {}),
},
}
formatted_tools.append(formatted_tool)
return formatted_tools
def _handle_tool_calls(
self,
response,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Any:
"""Handle tool calls from OpenAI response.
Args:
response: OpenAI API response
available_functions: Dict mapping function names to callables
from_task: Optional task context
from_agent: Optional agent context
Returns:
Result of function execution or error message
"""
message = response.choices[0].message
if not message.tool_calls or not available_functions:
return message.content
# Execute the first tool call
tool_call = message.tool_calls[0]
function_name = tool_call.function.name
function_args = {}
if function_name not in available_functions:
return f"Error: Function '{function_name}' not found in available functions"
try:
# Parse function arguments
function_args = json.loads(tool_call.function.arguments)
fn = available_functions[function_name]
# Execute function with event tracking
assert hasattr(crewai_event_bus, "emit")
started_at = datetime.now()
crewai_event_bus.emit(
self,
event=ToolUsageStartedEvent(
tool_name=function_name,
tool_args=function_args,
),
)
result = fn(**function_args)
crewai_event_bus.emit(
self,
event=ToolUsageFinishedEvent(
output=result,
tool_name=function_name,
tool_args=function_args,
started_at=started_at,
finished_at=datetime.now(),
),
)
# Emit success event
event_data = {
"response": result,
"call_type": LLMCallType.TOOL_CALL,
"model": self.model,
}
if from_task is not None:
event_data["from_task"] = from_task
if from_agent is not None:
event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**event_data),
)
return result
except json.JSONDecodeError as e:
error_msg = f"Error parsing function arguments: {e}"
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=error_msg,
),
)
return error_msg
except Exception as e:
error_msg = f"Error executing function '{function_name}': {e}"
crewai_event_bus.emit(
self,
event=ToolUsageErrorEvent(
tool_name=function_name,
tool_args=function_args,
error=error_msg,
),
)
return error_msg
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
"""Call OpenAI API with the given messages.
Args:
messages: Input messages for the LLM
tools: Optional list of tool schemas
callbacks: Optional callbacks to execute
available_functions: Optional dict of available functions
from_task: Optional task context
from_agent: Optional agent context
Returns:
LLM response or tool execution result
Raises:
ValueError: If messages format is invalid
RuntimeError: If API call fails
"""
# Emit call started event
print("calling from native openai", messages)
assert hasattr(crewai_event_bus, "emit")
# Prepare event data
started_event_data = {
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"model": self.model,
}
if from_task is not None:
started_event_data["from_task"] = from_task
if from_agent is not None:
started_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallStartedEvent(**started_event_data),
)
try:
# Format messages
formatted_messages = self._format_messages(messages)
# Prepare API call parameters
api_params = {
"model": self.model,
"messages": formatted_messages,
}
# Add optional parameters
if self.temperature is not None:
api_params["temperature"] = self.temperature
if self.top_p is not None:
api_params["top_p"] = self.top_p
if self.n is not None:
api_params["n"] = self.n
if self.max_tokens is not None:
api_params["max_tokens"] = self.max_tokens
if self.presence_penalty is not None:
api_params["presence_penalty"] = self.presence_penalty
if self.frequency_penalty is not None:
api_params["frequency_penalty"] = self.frequency_penalty
if self.logit_bias is not None:
api_params["logit_bias"] = self.logit_bias
if self.seed is not None:
api_params["seed"] = self.seed
if self.logprobs is not None:
api_params["logprobs"] = self.logprobs
if self.top_logprobs is not None:
api_params["top_logprobs"] = self.top_logprobs
if self.stop:
api_params["stop"] = self.stop
if self.response_format is not None:
# Handle structured output for Pydantic models
if hasattr(self.response_format, "model_json_schema"):
api_params["response_format"] = {
"type": "json_schema",
"json_schema": {
"name": self.response_format.__name__,
"schema": self.response_format.model_json_schema(),
"strict": True,
},
}
else:
api_params["response_format"] = self.response_format
if self.reasoning_effort is not None and "o1" in self.model:
api_params["reasoning_effort"] = self.reasoning_effort
# Add tools if provided and supported
formatted_tools = self._format_tools(tools)
if formatted_tools:
api_params["tools"] = formatted_tools
api_params["tool_choice"] = "auto"
# Execute callbacks before API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_start"):
callback.on_llm_start(
serialized={"name": self.__class__.__name__},
prompts=[str(formatted_messages)],
)
# Make API call
if self.stream:
response = self.client.chat.completions.create(
stream=True, **api_params
)
# Handle streaming (simplified for now)
full_response = ""
for chunk in response:
if (
hasattr(chunk.choices[0].delta, "content")
and chunk.choices[0].delta.content
):
full_response += chunk.choices[0].delta.content
result = full_response
else:
response = self.client.chat.completions.create(**api_params)
# Handle tool calls if present
result = self._handle_tool_calls(
response, available_functions, from_task, from_agent
)
# If no tool calls, return text content
if result == response.choices[0].message.content:
result = response.choices[0].message.content or ""
# Execute callbacks after API call
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_end"):
callback.on_llm_end(response=result)
# Emit completion event
completion_event_data = {
"messages": formatted_messages,
"response": result,
"call_type": LLMCallType.LLM_CALL,
"model": self.model,
}
if from_task is not None:
completion_event_data["from_task"] = from_task
if from_agent is not None:
completion_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallCompletedEvent(**completion_event_data),
)
return result
except Exception as e:
# Execute error callbacks
if callbacks:
for callback in callbacks:
if hasattr(callback, "on_llm_error"):
callback.on_llm_error(error=e)
# Emit failed event
failed_event_data = {
"error": str(e),
}
if from_task is not None:
failed_event_data["from_task"] = from_task
if from_agent is not None:
failed_event_data["from_agent"] = from_agent
crewai_event_bus.emit(
self,
event=LLMCallFailedEvent(**failed_event_data),
)
raise RuntimeError(f"OpenAI API call failed: {str(e)}") from e
def supports_stop_words(self) -> bool:
"""Check if OpenAI models support stop words."""
return True
def get_context_window_size(self) -> int:
"""Get the context window size for the current model."""
if self.context_window_size != 0:
return self.context_window_size
# Use 85% of the context window like the original LLM class
context_window = self.model_config.get("context_window", 4096)
self.context_window_size = int(context_window * 0.85)
return self.context_window_size
def supports_function_calling(self) -> bool:
"""Check if the current model supports function calling."""
return self.model_config.get("supports_tools", True)

View File

@@ -7,39 +7,71 @@ from crewai.llm import LLM, BaseLLM
def create_llm(
llm_value: Union[str, LLM, Any, None] = None,
prefer_native: Optional[bool] = None,
) -> Optional[LLM | BaseLLM]:
"""
Creates or returns an LLM instance based on the given llm_value.
Now supports provider prefixes like 'openai/gpt-4' for native implementations.
Args:
llm_value (str | BaseLLM | Any | None):
- str: The model name (e.g., "gpt-4").
- str: The model name (e.g., "gpt-4" or "openai/gpt-4").
- BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
- Any: Attempt to extract known attributes like model_name, temperature, etc.
- None: Use environment-based or fallback default model.
prefer_native (bool | None):
- True: Use native provider implementations when available
- False: Always use LiteLLM implementation
- None: Use environment variable CREWAI_PREFER_NATIVE_LLMS (default: True)
- Note: Provider prefixes (openai/, anthropic/, google/) override this setting
Returns:
A BaseLLM instance if successful, or None if something fails.
Examples:
create_llm("gpt-4") # Uses LiteLLM or native based on prefer_native
create_llm("openai/gpt-4") # Always uses native OpenAI implementation
create_llm("anthropic/claude-3-sonnet") # Future: native Anthropic
"""
# 1) If llm_value is already a BaseLLM or LLM object, return it directly
if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
return llm_value
# 2) If llm_value is a string (model name)
# 2) Determine if we should prefer native implementations (unless provider prefix is used)
if prefer_native is None:
prefer_native = os.getenv("CREWAI_PREFER_NATIVE_LLMS", "true").lower() in (
"true",
"1",
"yes",
)
# 3) If llm_value is a string (model name)
if isinstance(llm_value, str):
try:
# Provider prefix (openai/, anthropic/, google/) always takes precedence
if "/" in llm_value:
created_llm = LLM(model=llm_value) # LLM class handles routing
return created_llm
# Try native implementation first if preferred and no prefix
if prefer_native:
native_llm = _create_native_llm(llm_value)
if native_llm:
return native_llm
# Fallback to LiteLLM
created_llm = LLM(model=llm_value)
return created_llm
except Exception as e:
print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
return None
# 3) If llm_value is None, parse environment variables or use default
# 4) If llm_value is None, parse environment variables or use default
if llm_value is None:
return _llm_via_environment_or_fallback()
return _llm_via_environment_or_fallback(prefer_native)
# 4) Otherwise, attempt to extract relevant attributes from an unknown object
# 5) Otherwise, attempt to extract relevant attributes from an unknown object
try:
# Extract attributes with explicit types
model = (
@@ -48,6 +80,8 @@ def create_llm(
or getattr(llm_value, "deployment_name", None)
or str(llm_value)
)
# Extract other parameters
temperature: Optional[float] = getattr(llm_value, "temperature", None)
max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
@@ -56,6 +90,7 @@ def create_llm(
base_url: Optional[str] = getattr(llm_value, "base_url", None)
api_base: Optional[str] = getattr(llm_value, "api_base", None)
# Use LLM class constructor which handles routing
created_llm = LLM(
model=model,
temperature=temperature,
@@ -72,9 +107,94 @@ def create_llm(
return None
def _llm_via_environment_or_fallback() -> Optional[LLM]:
def _create_native_llm(model: str, **kwargs) -> Optional[BaseLLM]:
"""
Create a native LLM implementation based on the model name.
Args:
model: The model name (e.g., 'gpt-4', 'claude-3-sonnet')
**kwargs: Additional parameters for the LLM
Returns:
Native LLM instance if supported, None otherwise
"""
try:
# OpenAI models
if _is_openai_model(model):
from crewai.llms.openai import OpenAILLM
return OpenAILLM(model=model, **kwargs)
# Claude models
if _is_claude_model(model):
from crewai.llms.anthropic import ClaudeLLM
return ClaudeLLM(model=model, **kwargs)
# Gemini models
if _is_gemini_model(model):
from crewai.llms.google import GeminiLLM
return GeminiLLM(model=model, **kwargs)
# No native implementation found
return None
except Exception as e:
print(f"Failed to create native LLM for model '{model}': {e}")
return None
def _is_openai_model(model: str) -> bool:
"""Check if a model is from OpenAI."""
openai_prefixes = (
"gpt-",
"text-davinci",
"text-curie",
"text-babbage",
"text-ada",
"davinci",
"curie",
"babbage",
"ada",
"o1-",
"o3-",
"o4-",
"chatgpt-",
)
model_lower = model.lower()
return any(model_lower.startswith(prefix) for prefix in openai_prefixes)
def _is_claude_model(model: str) -> bool:
"""Check if a model is from Anthropic (Claude)."""
claude_prefixes = (
"claude-",
"claude", # For cases like just "claude"
)
model_lower = model.lower()
return any(model_lower.startswith(prefix) for prefix in claude_prefixes)
def _is_gemini_model(model: str) -> bool:
"""Check if a model is from Google (Gemini)."""
gemini_prefixes = (
"gemini-",
"gemini", # For cases like just "gemini"
)
model_lower = model.lower()
return any(model_lower.startswith(prefix) for prefix in gemini_prefixes)
def _llm_via_environment_or_fallback(
prefer_native: bool = True,
) -> Optional[LLM | BaseLLM]:
"""
Helper function: if llm_value is None, we load environment variables or fallback default model.
Now with native provider support.
"""
model_name = (
os.environ.get("MODEL")
@@ -83,7 +203,13 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
or DEFAULT_LLM_MODEL
)
# Initialize parameters with correct types
# Try native implementation first if preferred
if prefer_native:
native_llm = _create_native_llm(model_name)
if native_llm:
return native_llm
# Initialize parameters with correct types (original logic continues)
model: str = model_name
temperature: Optional[float] = None
max_tokens: Optional[int] = None