diff --git a/src/crewai/llm.py b/src/crewai/llm.py index c701ddf0b..0e9e7cd83 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -316,6 +316,143 @@ class LLM(BaseLLM): stream: bool = False, **kwargs, ): + # Check for provider prefixes and route to native implementations + if "/" in model: + provider, actual_model = model.split("/", 1) + + # Route to OpenAI native implementation + if provider.lower() == "openai": + try: + from crewai.llms.openai import OpenAILLM + + # Create native OpenAI instance with all the same parameters + native_llm = OpenAILLM( + model=actual_model, + timeout=timeout, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + response_format=response_format, + seed=seed, + logprobs=logprobs, + top_logprobs=top_logprobs, + base_url=base_url, + api_base=api_base, + api_version=api_version, + api_key=api_key, + callbacks=callbacks, + reasoning_effort=reasoning_effort, + stream=stream, + **kwargs, + ) + + # Replace this LLM instance with the native one + self.__class__ = native_llm.__class__ + self.__dict__.update(native_llm.__dict__) + return + + except ImportError: + # Fall back to LiteLLM if native implementation unavailable + print( + f"Native OpenAI implementation not available, using LiteLLM for {model}" + ) + model = actual_model # Remove the prefix for LiteLLM + + # Route to Claude native implementation + elif provider.lower() == "anthropic": + try: + from crewai.llms.anthropic import ClaudeLLM + + # Create native Claude instance with all the same parameters + native_llm = ClaudeLLM( + model=actual_model, + timeout=timeout, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + response_format=response_format, + seed=seed, + logprobs=logprobs, + top_logprobs=top_logprobs, + base_url=base_url, + api_base=api_base, + api_version=api_version, + api_key=api_key, + callbacks=callbacks, + reasoning_effort=reasoning_effort, + stream=stream, + **kwargs, + ) + + # Replace this LLM instance with the native one + self.__class__ = native_llm.__class__ + self.__dict__.update(native_llm.__dict__) + return + + except ImportError: + # Fall back to LiteLLM if native implementation unavailable + print( + f"Native Claude implementation not available, using LiteLLM for {model}" + ) + model = actual_model # Remove the prefix for LiteLLM + + # Route to Gemini native implementation + elif provider.lower() == "google": + try: + from crewai.llms.google import GeminiLLM + + # Create native Gemini instance with all the same parameters + native_llm = GeminiLLM( + model=actual_model, + timeout=timeout, + temperature=temperature, + top_p=top_p, + n=n, + stop=stop, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + presence_penalty=presence_penalty, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + response_format=response_format, + seed=seed, + logprobs=logprobs, + top_logprobs=top_logprobs, + base_url=base_url, + api_base=api_base, + api_version=api_version, + api_key=api_key, + callbacks=callbacks, + reasoning_effort=reasoning_effort, + stream=stream, + **kwargs, + ) + + # Replace this LLM instance with the native one + self.__class__ = native_llm.__class__ + 
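# A minimal sketch (illustrative classes, not part of the patch) of the routing
# pattern used above: split "provider/model" once, then swap the instance's
# class and state so the object becomes the native implementation in place.
class _NativeOpenAILLM:
    def __init__(self, model: str):
        self.model = model

class _GenericLLM:
    def __init__(self, model: str):
        if "/" in model:
            provider, actual_model = model.split("/", 1)
            if provider.lower() == "openai":
                native = _NativeOpenAILLM(actual_model)
                self.__class__ = native.__class__   # same trick as the patch
                self.__dict__.update(native.__dict__)
                return
        self.model = model  # no prefix: stay on the generic (LiteLLM) path

llm = _GenericLLM("openai/gpt-4o")
assert type(llm) is _NativeOpenAILLM and llm.model == "gpt-4o"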
self.__dict__.update(native_llm.__dict__) + return + + except ImportError: + # Fall back to LiteLLM if native implementation unavailable + print( + f"Native Gemini implementation not available, using LiteLLM for {model}" + ) + model = actual_model # Remove the prefix for LiteLLM + + # Continue with original LiteLLM initialization self.model = model self.timeout = timeout self.temperature = temperature @@ -1139,7 +1276,11 @@ class LLM(BaseLLM): # TODO: Remove this code after merging PR https://github.com/BerriAI/litellm/pull/10917 # Ollama doesn't supports last message to be 'assistant' - if "ollama" in self.model.lower() and messages and messages[-1]["role"] == "assistant": + if ( + "ollama" in self.model.lower() + and messages + and messages[-1]["role"] == "assistant" + ): return messages + [{"role": "user", "content": ""}] # Handle Anthropic models diff --git a/src/crewai/llms/__init__.py b/src/crewai/llms/__init__.py index fda1e6a3b..8529dec34 100644 --- a/src/crewai/llms/__init__.py +++ b/src/crewai/llms/__init__.py @@ -1 +1,11 @@ -"""LLM implementations for crewAI.""" +"""CrewAI LLM implementations.""" + +from .base_llm import BaseLLM +from .openai import OpenAILLM +from .anthropic import ClaudeLLM +from .google import GeminiLLM + +# Import the main LLM class for backward compatibility + + +__all__ = ["BaseLLM", "OpenAILLM", "ClaudeLLM", "GeminiLLM"] diff --git a/src/crewai/llms/anthropic/__init__.py b/src/crewai/llms/anthropic/__init__.py new file mode 100644 index 000000000..5cf867919 --- /dev/null +++ b/src/crewai/llms/anthropic/__init__.py @@ -0,0 +1,5 @@ +"""Anthropic Claude LLM implementation for CrewAI.""" + +from .claude import ClaudeLLM + +__all__ = ["ClaudeLLM"] diff --git a/src/crewai/llms/anthropic/claude.py b/src/crewai/llms/anthropic/claude.py new file mode 100644 index 000000000..45d6504d1 --- /dev/null +++ b/src/crewai/llms/anthropic/claude.py @@ -0,0 +1,569 @@ +import os +from typing import Any, Dict, List, Optional, Union, Type, Literal +from anthropic import Anthropic +from pydantic import BaseModel + +from crewai.llms.base_llm import BaseLLM +from crewai.utilities.events import crewai_event_bus +from crewai.utilities.events.llm_events import ( + LLMCallCompletedEvent, + LLMCallFailedEvent, + LLMCallStartedEvent, + LLMCallType, +) +from crewai.utilities.events.tool_usage_events import ( + ToolUsageStartedEvent, + ToolUsageFinishedEvent, + ToolUsageErrorEvent, +) +from datetime import datetime + + +class ClaudeLLM(BaseLLM): + """Anthropic Claude LLM implementation with full LLM class compatibility.""" + + def __init__( + self, + model: str = "claude-3-5-sonnet-20241022", + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, # Not supported by Claude but kept for compatibility + stop: Optional[Union[str, List[str]]] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[ + float + ] = None, # Not supported but kept for compatibility + frequency_penalty: Optional[ + float + ] = None, # Not supported but kept for compatibility + logit_bias: Optional[ + Dict[int, float] + ] = None, # Not supported but kept for compatibility + response_format: Optional[Type[BaseModel]] = None, + seed: Optional[int] = None, # Not supported but kept for compatibility + logprobs: Optional[int] = None, # Not supported but kept for compatibility + top_logprobs: Optional[int] = None, # Not supported but kept for compatibility + base_url: 
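# Usage sketch (assumes the package layout added in this patch): a
# "provider/model" string routes through LLM to the native class, while the
# native classes can also be imported directly from crewai.llms.
#
#   from crewai.llm import LLM
#   from crewai.llms import OpenAILLM, ClaudeLLM, GeminiLLM
#
#   llm = LLM(model="anthropic/claude-3-5-sonnet-20241022")  # becomes ClaudeLLM
#   llm = LLM(model="gpt-4o")                                # stays on the LiteLLM path
#   claude = ClaudeLLM(model="claude-3-5-sonnet-20241022", temperature=0.2)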
Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, # Not used by Anthropic + api_key: Optional[str] = None, + callbacks: List[Any] = [], + reasoning_effort: Optional[ + Literal["none", "low", "medium", "high"] + ] = None, # Not used by Claude + stream: bool = False, + max_retries: int = 2, + # Claude-specific parameters + thinking_mode: bool = False, # Enable Claude's thinking mode + top_k: Optional[int] = None, # Claude-specific sampling parameter + **kwargs, + ): + """Initialize Claude LLM with full compatibility. + + Args: + model: Claude model name (e.g., 'claude-3-5-sonnet-20241022') + timeout: Request timeout in seconds + temperature: Sampling temperature (0-1 for Claude) + top_p: Nucleus sampling parameter + n: Number of completions (not supported by Claude, kept for compatibility) + stop: Stop sequences + max_completion_tokens: Maximum tokens in completion + max_tokens: Maximum tokens (legacy parameter) + presence_penalty: Not supported by Claude, kept for compatibility + frequency_penalty: Not supported by Claude, kept for compatibility + logit_bias: Not supported by Claude, kept for compatibility + response_format: Pydantic model for structured output + seed: Not supported by Claude, kept for compatibility + logprobs: Not supported by Claude, kept for compatibility + top_logprobs: Not supported by Claude, kept for compatibility + base_url: Custom API base URL + api_base: Legacy API base parameter + api_version: Not used by Anthropic + api_key: Anthropic API key + callbacks: List of callback functions + reasoning_effort: Not used by Claude, kept for compatibility + stream: Whether to stream responses + max_retries: Number of retries for failed requests + thinking_mode: Enable Claude's thinking mode (if supported) + top_k: Claude-specific top-k sampling parameter + **kwargs: Additional parameters + """ + super().__init__(model=model, temperature=temperature) + + # Store all parameters for compatibility + self.timeout = timeout + self.top_p = top_p + self.n = n # Claude doesn't support n>1, but we store it for compatibility + self.max_completion_tokens = max_completion_tokens + self.max_tokens = max_tokens or max_completion_tokens + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.logit_bias = logit_bias + self.response_format = response_format + self.seed = seed + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.api_base = api_base or base_url + self.base_url = base_url or api_base + self.api_version = api_version + self.api_key = api_key or os.getenv("ANTHROPIC_API_KEY") + self.callbacks = callbacks + self.reasoning_effort = reasoning_effort + self.stream = stream + self.additional_params = kwargs + self.context_window_size = 0 + + # Claude-specific parameters + self.thinking_mode = thinking_mode + self.top_k = top_k + + # Normalize stop parameter to match LLM class behavior + if stop is None: + self.stop: List[str] = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = stop + + # Initialize Anthropic client + client_kwargs = {} + if self.api_key: + client_kwargs["api_key"] = self.api_key + if self.base_url: + client_kwargs["base_url"] = self.base_url + if self.timeout: + client_kwargs["timeout"] = self.timeout + if max_retries: + client_kwargs["max_retries"] = max_retries + + # Add any additional kwargs that might be relevant to the client + for key, value in kwargs.items(): + if key not in ["thinking_mode", "top_k"]: # Exclude our custom params + 
client_kwargs[key] = value + + self.client = Anthropic(**client_kwargs) + self.model_config = self._get_model_config() + + def _get_model_config(self) -> Dict[str, Any]: + """Get model-specific configuration for Claude models.""" + # Claude model configurations based on Anthropic's documentation + model_configs = { + # Claude 3.5 Sonnet + "claude-3-5-sonnet-20241022": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + "claude-3-5-sonnet-20240620": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + # Claude 3.5 Haiku + "claude-3-5-haiku-20241022": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + # Claude 3 Opus + "claude-3-opus-20240229": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + # Claude 3 Sonnet + "claude-3-sonnet-20240229": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + # Claude 3 Haiku + "claude-3-haiku-20240307": { + "context_window": 200000, + "supports_tools": True, + "supports_vision": True, + }, + # Claude 2.1 + "claude-2.1": { + "context_window": 200000, + "supports_tools": False, + "supports_vision": False, + }, + "claude-2": { + "context_window": 100000, + "supports_tools": False, + "supports_vision": False, + }, + # Claude Instant + "claude-instant-1.2": { + "context_window": 100000, + "supports_tools": False, + "supports_vision": False, + }, + } + + # Default config if model not found + default_config = { + "context_window": 200000, + "supports_tools": True, + "supports_vision": False, + } + + # Try exact match first + if self.model in model_configs: + return model_configs[self.model] + + # Try prefix match for versioned models + for model_prefix, config in model_configs.items(): + if self.model.startswith(model_prefix): + return config + + return default_config + + def _format_messages( + self, messages: Union[str, List[Dict[str, str]]] + ) -> List[Dict[str, str]]: + """Format messages for Anthropic API. + + Args: + messages: Input messages as string or list of dicts + + Returns: + List of properly formatted message dicts + """ + if isinstance(messages, str): + return [{"role": "user", "content": messages}] + + # Validate message format + for msg in messages: + if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: + raise ValueError( + "Each message must be a dict with 'role' and 'content' keys" + ) + + # Claude requires alternating user/assistant messages and cannot start with system + formatted_messages = [] + system_message = None + + for msg in messages: + if msg["role"] == "system": + # Store system message separately - Claude handles it differently + if system_message is None: + system_message = msg["content"] + else: + system_message += "\n\n" + msg["content"] + else: + formatted_messages.append(msg) + + # Ensure messages alternate and start with user + if formatted_messages and formatted_messages[0]["role"] != "user": + formatted_messages.insert(0, {"role": "user", "content": "Hello"}) + + # Store system message for later use + self._system_message = system_message + + return formatted_messages + + def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]: + """Format tools for Claude function calling. 
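# Standalone sketch of the normalization performed by _format_messages above:
# system turns are pulled out (Claude takes `system` as a separate field) and
# the remaining turns must start with a user message. Illustrative only, not
# the patched method itself.
def _split_system(messages):
    system_parts, turns = [], []
    for msg in messages:
        if msg["role"] == "system":
            system_parts.append(msg["content"])
        else:
            turns.append(msg)
    if turns and turns[0]["role"] != "user":
        turns.insert(0, {"role": "user", "content": "Hello"})
    return ("\n\n".join(system_parts) or None), turns

system, turns = _split_system([
    {"role": "system", "content": "You are terse."},
    {"role": "assistant", "content": "Ready."},
    {"role": "user", "content": "Summarize the report."},
])
# system == "You are terse."; turns now begin with the inserted user message.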
+ + Args: + tools: List of tool definitions + + Returns: + Claude-formatted tool definitions + """ + if not tools or not self.model_config.get("supports_tools", True): + return None + + formatted_tools = [] + for tool in tools: + # Convert to Claude tool format + formatted_tool = { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "input_schema": tool.get("parameters", {}), + } + formatted_tools.append(formatted_tool) + + return formatted_tools + + def _handle_tool_calls( + self, + response, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Any: + """Handle tool calls from Claude response. + + Args: + response: Claude API response + available_functions: Dict mapping function names to callables + from_task: Optional task context + from_agent: Optional agent context + + Returns: + Result of function execution or error message + """ + # Claude returns tool use in content blocks + if not hasattr(response, "content") or not available_functions: + return response.content[0].text if response.content else "" + + # Look for tool use blocks + for content_block in response.content: + if hasattr(content_block, "type") and content_block.type == "tool_use": + function_name = content_block.name + function_args = {} + + if function_name not in available_functions: + return f"Error: Function '{function_name}' not found in available functions" + + try: + # Claude provides arguments as a dict + function_args = content_block.input + fn = available_functions[function_name] + + # Execute function with event tracking + assert hasattr(crewai_event_bus, "emit") + started_at = datetime.now() + crewai_event_bus.emit( + self, + event=ToolUsageStartedEvent( + tool_name=function_name, + tool_args=function_args, + ), + ) + + result = fn(**function_args) + + crewai_event_bus.emit( + self, + event=ToolUsageFinishedEvent( + output=result, + tool_name=function_name, + tool_args=function_args, + started_at=started_at, + finished_at=datetime.now(), + ), + ) + + # Emit success event + event_data = { + "response": result, + "call_type": LLMCallType.TOOL_CALL, + "model": self.model, + } + if from_task is not None: + event_data["from_task"] = from_task + if from_agent is not None: + event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**event_data), + ) + + return result + + except Exception as e: + error_msg = f"Error executing function '{function_name}': {e}" + crewai_event_bus.emit( + self, + event=ToolUsageErrorEvent( + tool_name=function_name, + tool_args=function_args, + error=error_msg, + ), + ) + return error_msg + + # If no tool calls, return text content + return response.content[0].text if response.content else "" + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Union[str, Any]: + """Call Claude API with the given messages. 
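# Sketch of the tool-schema conversion done in _format_tools above: CrewAI's
# {"name", "description", "parameters"} shape maps onto Claude's
# {"name", "description", "input_schema"} shape. The example schema is made up.
crewai_tool = {
    "name": "get_weather",
    "description": "Look up current weather for a city.",
    "parameters": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
claude_tool = {
    "name": crewai_tool["name"],
    "description": crewai_tool["description"],
    "input_schema": crewai_tool["parameters"],
}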
+ + Args: + messages: Input messages for the LLM + tools: Optional list of tool schemas + callbacks: Optional callbacks to execute + available_functions: Optional dict of available functions + from_task: Optional task context + from_agent: Optional agent context + + Returns: + LLM response or tool execution result + + Raises: + ValueError: If messages format is invalid + RuntimeError: If API call fails + """ + # Emit call started event + print("calling from native claude", messages) + assert hasattr(crewai_event_bus, "emit") + + # Prepare event data + started_event_data = { + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + "model": self.model, + } + if from_task is not None: + started_event_data["from_task"] = from_task + if from_agent is not None: + started_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallStartedEvent(**started_event_data), + ) + + try: + # Format messages + formatted_messages = self._format_messages(messages) + system_message = getattr(self, "_system_message", None) + + # Prepare API call parameters + api_params = { + "model": self.model, + "messages": formatted_messages, + "max_tokens": self.max_tokens or 4000, # Claude requires max_tokens + } + + # Add system message if present + if system_message: + api_params["system"] = system_message + + # Add optional parameters that Claude supports + if self.temperature is not None: + api_params["temperature"] = self.temperature + + if self.top_p is not None: + api_params["top_p"] = self.top_p + + if self.top_k is not None: + api_params["top_k"] = self.top_k + + if self.stop: + api_params["stop_sequences"] = self.stop + + # Add tools if provided and supported + formatted_tools = self._format_tools(tools) + if formatted_tools: + api_params["tools"] = formatted_tools + + # Execute callbacks before API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_start"): + callback.on_llm_start( + serialized={"name": self.__class__.__name__}, + prompts=[str(formatted_messages)], + ) + + # Make API call + if self.stream: + response = self.client.messages.create(stream=True, **api_params) + # Handle streaming (simplified implementation) + full_response = "" + try: + for event in response: + if hasattr(event, "type"): + if event.type == "content_block_delta": + if hasattr(event, "delta") and hasattr( + event.delta, "text" + ): + full_response += event.delta.text + except Exception as e: + # If streaming fails, fall back to the response we have + print(f"Streaming error (continuing with partial response): {e}") + result = full_response or "No response content" + else: + response = self.client.messages.create(**api_params) + # Handle tool calls if present + result = self._handle_tool_calls( + response, available_functions, from_task, from_agent + ) + + # Execute callbacks after API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_end"): + callback.on_llm_end(response=result) + + # Emit completion event + completion_event_data = { + "messages": formatted_messages, + "response": result, + "call_type": LLMCallType.LLM_CALL, + "model": self.model, + } + if from_task is not None: + completion_event_data["from_task"] = from_task + if from_agent is not None: + completion_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**completion_event_data), + ) + + return result + + except Exception as e: + # Execute error callbacks + if callbacks: + 
for callback in callbacks: + if hasattr(callback, "on_llm_error"): + callback.on_llm_error(error=e) + + # Emit failed event + failed_event_data = { + "error": str(e), + } + if from_task is not None: + failed_event_data["from_task"] = from_task + if from_agent is not None: + failed_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallFailedEvent(**failed_event_data), + ) + + raise RuntimeError(f"Claude API call failed: {str(e)}") from e + + def supports_stop_words(self) -> bool: + """Check if Claude models support stop words.""" + return True + + def get_context_window_size(self) -> int: + """Get the context window size for the current model.""" + if self.context_window_size != 0: + return self.context_window_size + + # Use 85% of the context window like the original LLM class + context_window = self.model_config.get("context_window", 200000) + self.context_window_size = int(context_window * 0.85) + return self.context_window_size + + def supports_function_calling(self) -> bool: + """Check if the current model supports function calling.""" + return self.model_config.get("supports_tools", True) + + def supports_vision(self) -> bool: + """Check if the current model supports vision capabilities.""" + return self.model_config.get("supports_vision", False) diff --git a/src/crewai/llms/google/__init__.py b/src/crewai/llms/google/__init__.py new file mode 100644 index 000000000..b913c1949 --- /dev/null +++ b/src/crewai/llms/google/__init__.py @@ -0,0 +1,5 @@ +"""Google Gemini LLM implementation for CrewAI.""" + +from .gemini import GeminiLLM + +__all__ = ["GeminiLLM"] diff --git a/src/crewai/llms/google/gemini.py b/src/crewai/llms/google/gemini.py new file mode 100644 index 000000000..96de2973e --- /dev/null +++ b/src/crewai/llms/google/gemini.py @@ -0,0 +1,737 @@ +import os +from typing import Any, Dict, List, Optional, Union, Type, Literal, TYPE_CHECKING +from pydantic import BaseModel + +if TYPE_CHECKING: + from google import genai + from google.genai import types + +try: + from google import genai + from google.genai import types +except ImportError: + genai = None + types = None + +from crewai.llms.base_llm import BaseLLM +from crewai.utilities.events import crewai_event_bus +from crewai.utilities.events.llm_events import ( + LLMCallCompletedEvent, + LLMCallFailedEvent, + LLMCallStartedEvent, + LLMCallType, +) +from crewai.utilities.events.tool_usage_events import ( + ToolUsageStartedEvent, + ToolUsageFinishedEvent, + ToolUsageErrorEvent, +) +from datetime import datetime + + +class GeminiLLM(BaseLLM): + """Google Gemini LLM implementation using the official Google Gen AI Python SDK.""" + + def __init__( + self, + model: str = "gemini-1.5-pro", + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, # Not supported by Gemini but kept for compatibility + stop: Optional[Union[str, List[str]]] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[ + float + ] = None, # Not supported but kept for compatibility + frequency_penalty: Optional[ + float + ] = None, # Not supported but kept for compatibility + logit_bias: Optional[ + Dict[int, float] + ] = None, # Not supported but kept for compatibility + response_format: Optional[Type[BaseModel]] = None, + seed: Optional[int] = None, # Not supported but kept for compatibility + logprobs: Optional[int] = None, # Not supported but kept for compatibility + 
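# Usage sketch for the ClaudeLLM class defined above (assumes ANTHROPIC_API_KEY
# is set and the anthropic package is installed; the tool and messages are
# illustrative).
from crewai.llms.anthropic import ClaudeLLM

def add(a: int, b: int) -> int:
    return a + b

llm = ClaudeLLM(model="claude-3-5-sonnet-20241022", max_tokens=512)
answer = llm.call(
    messages=[{"role": "user", "content": "Use the add tool to compute 2 + 3."}],
    tools=[{
        "name": "add",
        "description": "Add two integers.",
        "parameters": {
            "type": "object",
            "properties": {"a": {"type": "integer"}, "b": {"type": "integer"}},
            "required": ["a", "b"],
        },
    }],
    available_functions={"add": add},
)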
top_logprobs: Optional[int] = None, # Not supported but kept for compatibility + base_url: Optional[str] = None, # Not used by Gemini + api_base: Optional[str] = None, # Not used by Gemini + api_version: Optional[str] = None, # Not used by Gemini + api_key: Optional[str] = None, + callbacks: List[Any] = [], + reasoning_effort: Optional[ + Literal["none", "low", "medium", "high"] + ] = None, # Not used by Gemini + stream: bool = False, + max_retries: int = 2, + # Gemini-specific parameters + top_k: Optional[int] = None, # Gemini top-k sampling parameter + candidate_count: int = 1, # Number of response candidates + safety_settings: Optional[ + List[Dict[str, Any]] + ] = None, # Gemini safety settings + generation_config: Optional[ + Dict[str, Any] + ] = None, # Additional generation config + # Vertex AI parameters + use_vertex_ai: bool = False, + project_id: Optional[str] = None, + location: str = "us-central1", + **kwargs, + ): + """Initialize Gemini LLM with the official Google Gen AI SDK. + + Args: + model: Gemini model name (e.g., 'gemini-1.5-pro', 'gemini-2.0-flash-001') + timeout: Request timeout in seconds + temperature: Sampling temperature (0-2 for Gemini) + top_p: Nucleus sampling parameter + n: Number of completions (not supported by Gemini, kept for compatibility) + stop: Stop sequences + max_completion_tokens: Maximum tokens in completion + max_tokens: Maximum tokens (legacy parameter) + presence_penalty: Not supported by Gemini, kept for compatibility + frequency_penalty: Not supported by Gemini, kept for compatibility + logit_bias: Not supported by Gemini, kept for compatibility + response_format: Pydantic model for structured output + seed: Not supported by Gemini, kept for compatibility + logprobs: Not supported by Gemini, kept for compatibility + top_logprobs: Not supported by Gemini, kept for compatibility + base_url: Not used by Gemini + api_base: Not used by Gemini + api_version: Not used by Gemini + api_key: Google AI API key + callbacks: List of callback functions + reasoning_effort: Not used by Gemini, kept for compatibility + stream: Whether to stream responses + max_retries: Number of retries for failed requests + top_k: Gemini-specific top-k sampling parameter + candidate_count: Number of response candidates to generate + safety_settings: Gemini safety settings configuration + generation_config: Additional Gemini generation configuration + use_vertex_ai: Whether to use Vertex AI instead of Gemini API + project_id: Google Cloud project ID (required for Vertex AI) + location: Google Cloud region (default: us-central1) + **kwargs: Additional parameters + """ + # Check if Google Gen AI SDK is available + if genai is None or types is None: + raise ImportError( + "Google Gen AI Python SDK is required. 
Please install it with: " + "pip install google-genai" + ) + + super().__init__(model=model, temperature=temperature) + + # Store all parameters for compatibility + self.timeout = timeout + self.top_p = top_p + self.n = n + self.max_completion_tokens = max_completion_tokens + self.max_tokens = max_tokens or max_completion_tokens + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.logit_bias = logit_bias + self.response_format = response_format + self.seed = seed + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.api_base = api_base + self.base_url = base_url + self.api_version = api_version + self.callbacks = callbacks + self.reasoning_effort = reasoning_effort + self.stream = stream + self.additional_params = kwargs + self.context_window_size = 0 + self.max_retries = max_retries + + # Gemini-specific parameters + self.top_k = top_k + self.candidate_count = candidate_count + self.safety_settings = safety_settings or [] + self.generation_config = generation_config or {} + + # Vertex AI parameters + self.use_vertex_ai = use_vertex_ai + self.project_id = project_id or os.getenv("GOOGLE_CLOUD_PROJECT") + self.location = location + + # API key handling + self.api_key = ( + api_key + or os.getenv("GOOGLE_AI_API_KEY") + or os.getenv("GEMINI_API_KEY") + or os.getenv("GOOGLE_API_KEY") + ) + + # Normalize stop parameter to match LLM class behavior + if stop is None: + self.stop: List[str] = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = stop + + # Initialize client attribute + self.client: Any = None + + # Initialize the Google Gen AI client + self._initialize_client() + self.model_config = self._get_model_config() + + def _initialize_client(self): + """Initialize the Google Gen AI client.""" + if genai is None or types is None: + return + + try: + if self.use_vertex_ai: + if not self.project_id: + raise ValueError( + "project_id is required when use_vertex_ai=True. " + "Set it directly or via GOOGLE_CLOUD_PROJECT environment variable." + ) + self.client = genai.Client( + vertexai=True, + project=self.project_id, + location=self.location, + ) + else: + if not self.api_key: + raise ValueError( + "API key is required for Gemini Developer API. " + "Set it via api_key parameter or GOOGLE_AI_API_KEY/GEMINI_API_KEY environment variable." 
+ ) + self.client = genai.Client(api_key=self.api_key) + except Exception as e: + raise RuntimeError( + f"Failed to initialize Google Gen AI client: {str(e)}" + ) from e + + def _get_model_config(self) -> Dict[str, Any]: + """Get model-specific configuration for Gemini models.""" + # Gemini model configurations based on Google's documentation + model_configs = { + # Gemini 2.0 Flash (latest) + "gemini-2.0-flash": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-2.0-flash-001": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-2.0-flash-exp": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + # Gemini 1.5 Pro + "gemini-1.5-pro": { + "context_window": 2097152, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-pro-002": { + "context_window": 2097152, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-pro-001": { + "context_window": 2097152, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-pro-exp-0827": { + "context_window": 2097152, + "supports_tools": True, + "supports_vision": True, + }, + # Gemini 1.5 Flash + "gemini-1.5-flash": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-flash-002": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-flash-001": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-flash-8b": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + "gemini-1.5-flash-8b-exp-0827": { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + }, + # Legacy Gemini Pro + "gemini-pro": { + "context_window": 30720, + "supports_tools": True, + "supports_vision": False, + }, + "gemini-pro-vision": { + "context_window": 16384, + "supports_tools": False, + "supports_vision": True, + }, + # Gemini Ultra (when available) + "gemini-ultra": { + "context_window": 30720, + "supports_tools": True, + "supports_vision": True, + }, + } + + # Default config if model not found + default_config = { + "context_window": 1048576, + "supports_tools": True, + "supports_vision": True, + } + + # Try exact match first + if self.model in model_configs: + return model_configs[self.model] + + # Try prefix match for versioned models + for model_prefix, config in model_configs.items(): + if self.model.startswith(model_prefix): + return config + + return default_config + + def _format_messages(self, messages: Union[str, List[Dict[str, str]]]) -> List[Any]: + """Format messages for Google Gen AI SDK. 
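# Sketch of the two client modes handled by _initialize_client above, calling
# the google-genai SDK directly (project and region values are placeholders).
from google import genai

# Gemini Developer API: key-based authentication.
client = genai.Client(api_key="YOUR_GEMINI_API_KEY")

# Vertex AI: project + location instead of an API key.
vertex_client = genai.Client(
    vertexai=True,
    project="my-gcp-project",
    location="us-central1",
)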
+ + Args: + messages: Input messages as string or list of dicts + + Returns: + List of properly formatted Content objects + """ + if genai is None or types is None: + return [] + + if isinstance(messages, str): + return [ + types.Content(role="user", parts=[types.Part.from_text(text=messages)]) + ] + + # Validate message format + for msg in messages: + if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: + raise ValueError( + "Each message must be a dict with 'role' and 'content' keys" + ) + + # Convert to Google Gen AI SDK format + formatted_messages = [] + system_instruction = None + + for msg in messages: + role = msg["role"] + content = msg["content"] + + if role == "system": + # System instruction will be handled separately + system_instruction = content + elif role == "user": + formatted_messages.append( + types.Content( + role="user", parts=[types.Part.from_text(text=content)] + ) + ) + elif role == "assistant": + formatted_messages.append( + types.Content( + role="model", parts=[types.Part.from_text(text=content)] + ) + ) + + # Store system instruction for later use + self._system_instruction = system_instruction + + return formatted_messages + + def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[Any]]: + """Format tools for Google Gen AI SDK function calling. + + Args: + tools: List of tool definitions + + Returns: + Google Gen AI SDK formatted tool definitions + """ + if genai is None or types is None: + return None + + if not tools or not self.model_config.get("supports_tools", True): + return None + + formatted_tools = [] + for tool in tools: + # Convert to Google Gen AI SDK function declaration format + function_declaration = types.FunctionDeclaration( + name=tool.get("name", ""), + description=tool.get("description", ""), + parameters=tool.get("parameters", {}), + ) + formatted_tools.append( + types.Tool(function_declarations=[function_declaration]) + ) + + return formatted_tools + + def _build_generation_config( + self, + system_instruction: Optional[str] = None, + tools: Optional[List[Any]] = None, + ) -> Any: + """Build Google Gen AI SDK generation config from parameters.""" + if genai is None or types is None: + return {} + config_dict = self.generation_config.copy() + + # Add parameters that map to Gemini's generation config + if self.temperature is not None: + config_dict["temperature"] = self.temperature + + if self.top_p is not None: + config_dict["top_p"] = self.top_p + + if self.top_k is not None: + config_dict["top_k"] = self.top_k + + if self.max_tokens is not None: + config_dict["max_output_tokens"] = self.max_tokens + + if self.candidate_count is not None: + config_dict["candidate_count"] = self.candidate_count + + if self.stop: + config_dict["stop_sequences"] = self.stop + + if self.stream: + config_dict["stream"] = True + + # Add safety settings + if self.safety_settings: + config_dict["safety_settings"] = self.safety_settings + + # Add response format if specified + if self.response_format: + config_dict["response_modalities"] = ["TEXT"] + + # Add system instruction if present + if system_instruction: + config_dict["system_instruction"] = system_instruction + + # Add tools if present + if tools: + config_dict["tools"] = tools + + return types.GenerateContentConfig(**config_dict) + + def _handle_tool_calls( + self, + response, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Any: + """Handle tool calls from Google Gen AI SDK response. 
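# Sketch of the config object that _build_generation_config assembles above.
# Values are illustrative; field names follow the google-genai SDK.
from google.genai import types

config = types.GenerateContentConfig(
    temperature=0.3,
    top_p=0.9,
    max_output_tokens=1024,
    stop_sequences=["Observation:"],
    system_instruction="You are a concise research assistant.",
)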
+ + Args: + response: Google Gen AI SDK response + available_functions: Dict mapping function names to callables + from_task: Optional task context + from_agent: Optional agent context + + Returns: + Result of function execution or error message + """ + # Check if response has function calls + if ( + not available_functions + or not hasattr(response, "candidates") + or not response.candidates + ): + return response.text if hasattr(response, "text") else str(response) + + candidate = response.candidates[0] if response.candidates else None + if ( + not candidate + or not hasattr(candidate, "content") + or not hasattr(candidate.content, "parts") + ): + return response.text if hasattr(response, "text") else str(response) + + # Look for function call parts + for part in candidate.content.parts: + if hasattr(part, "function_call"): + function_call = part.function_call + function_name = function_call.name + function_args = {} + + if function_name not in available_functions: + return f"Error: Function '{function_name}' not found in available functions" + + try: + # Google Gen AI SDK provides arguments as a struct + function_args = ( + dict(function_call.args) + if hasattr(function_call, "args") + else {} + ) + fn = available_functions[function_name] + + # Execute function with event tracking + assert hasattr(crewai_event_bus, "emit") + started_at = datetime.now() + crewai_event_bus.emit( + self, + event=ToolUsageStartedEvent( + tool_name=function_name, + tool_args=function_args, + ), + ) + + result = fn(**function_args) + + crewai_event_bus.emit( + self, + event=ToolUsageFinishedEvent( + output=result, + tool_name=function_name, + tool_args=function_args, + started_at=started_at, + finished_at=datetime.now(), + ), + ) + + # Emit success event + event_data = { + "response": result, + "call_type": LLMCallType.TOOL_CALL, + "model": self.model, + } + if from_task is not None: + event_data["from_task"] = from_task + if from_agent is not None: + event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**event_data), + ) + + return result + + except Exception as e: + error_msg = f"Error executing function '{function_name}': {e}" + crewai_event_bus.emit( + self, + event=ToolUsageErrorEvent( + tool_name=function_name, + tool_args=function_args, + error=error_msg, + ), + ) + return error_msg + + # If no function calls, return text content + return response.text if hasattr(response, "text") else str(response) + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Union[str, Any]: + """Call Google Gen AI SDK with the given messages. 
+ + Args: + messages: Input messages for the LLM + tools: Optional list of tool schemas + callbacks: Optional callbacks to execute + available_functions: Optional dict of available functions + from_task: Optional task context + from_agent: Optional agent context + + Returns: + LLM response or tool execution result + + Raises: + ValueError: If messages format is invalid + RuntimeError: If API call fails + """ + # Emit call started event + print("calling from native gemini", messages) + assert hasattr(crewai_event_bus, "emit") + + # Prepare event data + started_event_data = { + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + "model": self.model, + } + if from_task is not None: + started_event_data["from_task"] = from_task + if from_agent is not None: + started_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallStartedEvent(**started_event_data), + ) + + retry_count = 0 + last_error = None + + while retry_count <= self.max_retries: + try: + # Format messages + formatted_messages = self._format_messages(messages) + system_instruction = getattr(self, "_system_instruction", None) + + # Format tools if provided and supported + formatted_tools = self._format_tools(tools) + + # Build generation config + generation_config = self._build_generation_config( + system_instruction, formatted_tools + ) + + # Execute callbacks before API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_start"): + callback.on_llm_start( + serialized={"name": self.__class__.__name__}, + prompts=[str(formatted_messages)], + ) + + # Prepare the API call parameters + api_params = { + "model": self.model, + "contents": formatted_messages, + "config": generation_config, + } + + # Make API call + if self.stream: + # Streaming response + response_stream = self.client.models.generate_content(**api_params) + + full_response = "" + try: + for chunk in response_stream: + if hasattr(chunk, "text") and chunk.text: + full_response += chunk.text + except Exception as e: + print( + f"Streaming error (continuing with partial response): {e}" + ) + + result = full_response or "No response content" + else: + # Non-streaming response + response = self.client.models.generate_content(**api_params) + + # Handle tool calls if present + result = self._handle_tool_calls( + response, available_functions, from_task, from_agent + ) + + # Execute callbacks after API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_end"): + callback.on_llm_end(response=result) + + # Emit completion event + completion_event_data = { + "messages": messages, # Use original messages, not formatted_messages + "response": result, + "call_type": LLMCallType.LLM_CALL, + "model": self.model, + } + if from_task is not None: + completion_event_data["from_task"] = from_task + if from_agent is not None: + completion_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**completion_event_data), + ) + + return result + + except Exception as e: + last_error = e + retry_count += 1 + + if retry_count <= self.max_retries: + print( + f"Gemini API call failed (attempt {retry_count}/{self.max_retries + 1}): {e}" + ) + continue + + # All retries exhausted + # Execute error callbacks + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_error"): + callback.on_llm_error(error=e) + + # Emit failed event + crewai_event_bus.emit( + self, + 
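# Standalone sketch of the bounded-retry pattern used in GeminiLLM.call above:
# retry up to max_retries times, then surface the last error.
def call_with_retries(fn, max_retries: int = 2):
    last_error = None
    for attempt in range(max_retries + 1):
        try:
            return fn()
        except Exception as exc:  # broad catch, mirroring the method above
            last_error = exc
            if attempt < max_retries:
                print(f"Call failed (attempt {attempt + 1}/{max_retries + 1}): {exc}")
    raise RuntimeError(f"Call failed after {max_retries + 1} attempts") from last_error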
event=LLMCallFailedEvent(error=str(e)), + ) + + raise RuntimeError( + f"Gemini API call failed after {self.max_retries + 1} attempts: {str(e)}" + ) from e + + def supports_stop_words(self) -> bool: + """Check if Gemini models support stop words.""" + return True + + def get_context_window_size(self) -> int: + """Get the context window size for the current model.""" + if self.context_window_size != 0: + return self.context_window_size + + # Use 85% of the context window like the original LLM class + context_window = self.model_config.get("context_window", 1048576) + self.context_window_size = int(context_window * 0.85) + return self.context_window_size + + def supports_function_calling(self) -> bool: + """Check if the current model supports function calling.""" + return self.model_config.get("supports_tools", True) + + def supports_vision(self) -> bool: + """Check if the current model supports vision capabilities.""" + return self.model_config.get("supports_vision", False) diff --git a/src/crewai/llms/openai/__init__.py b/src/crewai/llms/openai/__init__.py new file mode 100644 index 000000000..e9978cf9c --- /dev/null +++ b/src/crewai/llms/openai/__init__.py @@ -0,0 +1,5 @@ +"""OpenAI LLM implementation for CrewAI.""" + +from .chat import OpenAILLM + +__all__ = ["OpenAILLM"] diff --git a/src/crewai/llms/openai/chat.py b/src/crewai/llms/openai/chat.py new file mode 100644 index 000000000..5f00ac4c7 --- /dev/null +++ b/src/crewai/llms/openai/chat.py @@ -0,0 +1,529 @@ +import json +import os +from typing import Any, Dict, List, Optional, Union, Type, Literal +from openai import OpenAI +from pydantic import BaseModel + +from crewai.llms.base_llm import BaseLLM +from crewai.utilities.events import crewai_event_bus +from crewai.utilities.events.llm_events import ( + LLMCallCompletedEvent, + LLMCallFailedEvent, + LLMCallStartedEvent, + LLMCallType, +) +from crewai.utilities.events.tool_usage_events import ( + ToolUsageStartedEvent, + ToolUsageFinishedEvent, + ToolUsageErrorEvent, +) +from datetime import datetime + + +class OpenAILLM(BaseLLM): + """OpenAI LLM implementation with full LLM class compatibility.""" + + def __init__( + self, + model: str = "gpt-4", + timeout: Optional[Union[float, int]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + n: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + presence_penalty: Optional[float] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[int, float]] = None, + response_format: Optional[Type[BaseModel]] = None, + seed: Optional[int] = None, + logprobs: Optional[int] = None, + top_logprobs: Optional[int] = None, + base_url: Optional[str] = None, + api_base: Optional[str] = None, + api_version: Optional[str] = None, + api_key: Optional[str] = None, + callbacks: List[Any] = [], + reasoning_effort: Optional[Literal["none", "low", "medium", "high"]] = None, + stream: bool = False, + max_retries: int = 2, + **kwargs, + ): + """Initialize OpenAI LLM with full compatibility. 
+ + Args: + model: OpenAI model name (e.g., 'gpt-4', 'gpt-3.5-turbo') + timeout: Request timeout in seconds + temperature: Sampling temperature (0-2) + top_p: Nucleus sampling parameter + n: Number of completions to generate + stop: Stop sequences + max_completion_tokens: Maximum tokens in completion + max_tokens: Maximum tokens (legacy parameter) + presence_penalty: Presence penalty (-2 to 2) + frequency_penalty: Frequency penalty (-2 to 2) + logit_bias: Logit bias dictionary + response_format: Pydantic model for structured output + seed: Random seed for deterministic output + logprobs: Whether to return log probabilities + top_logprobs: Number of most likely tokens to return + base_url: Custom API base URL + api_base: Legacy API base parameter + api_version: API version (for Azure) + api_key: OpenAI API key + callbacks: List of callback functions + reasoning_effort: Reasoning effort for o1 models + stream: Whether to stream responses + max_retries: Number of retries for failed requests + **kwargs: Additional parameters + """ + super().__init__(model=model, temperature=temperature) + + # Store all parameters for compatibility + self.timeout = timeout + self.top_p = top_p + self.n = n + self.max_completion_tokens = max_completion_tokens + self.max_tokens = max_tokens or max_completion_tokens + self.presence_penalty = presence_penalty + self.frequency_penalty = frequency_penalty + self.logit_bias = logit_bias + self.response_format = response_format + self.seed = seed + self.logprobs = logprobs + self.top_logprobs = top_logprobs + self.api_base = api_base or base_url + self.base_url = base_url or api_base + self.api_version = api_version + self.api_key = api_key or os.getenv("OPENAI_API_KEY") + self.callbacks = callbacks + self.reasoning_effort = reasoning_effort + self.stream = stream + self.additional_params = kwargs + self.context_window_size = 0 + + # Normalize stop parameter to match LLM class behavior + if stop is None: + self.stop: List[str] = [] + elif isinstance(stop, str): + self.stop = [stop] + else: + self.stop = stop + + # Initialize OpenAI client + self.client = OpenAI( + api_key=self.api_key, + base_url=self.base_url, + timeout=self.timeout, + max_retries=max_retries, + **kwargs, + ) + + self.model_config = self._get_model_config() + + def _get_model_config(self) -> Dict[str, Any]: + """Get model-specific configuration.""" + # Enhanced model configurations matching current LLM_CONTEXT_WINDOW_SIZES + model_configs = { + "gpt-4": {"context_window": 8192, "supports_tools": True}, + "gpt-4o": {"context_window": 128000, "supports_tools": True}, + "gpt-4o-mini": {"context_window": 200000, "supports_tools": True}, + "gpt-4-turbo": {"context_window": 128000, "supports_tools": True}, + "gpt-4.1": {"context_window": 1047576, "supports_tools": True}, + "gpt-4.1-mini": {"context_window": 1047576, "supports_tools": True}, + "gpt-4.1-nano": {"context_window": 1047576, "supports_tools": True}, + "gpt-3.5-turbo": {"context_window": 16385, "supports_tools": True}, + "o1-preview": {"context_window": 128000, "supports_tools": False}, + "o1-mini": {"context_window": 128000, "supports_tools": False}, + "o3-mini": {"context_window": 200000, "supports_tools": False}, + "o4-mini": {"context_window": 200000, "supports_tools": False}, + } + + # Default config if model not found + default_config = {"context_window": 4096, "supports_tools": True} + + for model_prefix, config in model_configs.items(): + if self.model.startswith(model_prefix): + return config + + return default_config + + def 
_format_messages( + self, messages: Union[str, List[Dict[str, str]]] + ) -> List[Dict[str, str]]: + """Format messages for OpenAI API. + + Args: + messages: Input messages as string or list of dicts + + Returns: + List of properly formatted message dicts + """ + if isinstance(messages, str): + return [{"role": "user", "content": messages}] + + # Validate message format + for msg in messages: + if not isinstance(msg, dict) or "role" not in msg or "content" not in msg: + raise ValueError( + "Each message must be a dict with 'role' and 'content' keys" + ) + + # Handle O1 model special case (system messages not supported) + if "o1" in self.model.lower(): + formatted_messages = [] + for msg in messages: + if msg["role"] == "system": + # Convert system messages to assistant messages for O1 + formatted_messages.append( + {"role": "assistant", "content": msg["content"]} + ) + else: + formatted_messages.append(msg) + return formatted_messages + + return messages + + def _format_tools(self, tools: Optional[List[dict]]) -> Optional[List[dict]]: + """Format tools for OpenAI function calling. + + Args: + tools: List of tool definitions + + Returns: + OpenAI-formatted tool definitions + """ + if not tools or not self.model_config.get("supports_tools", True): + return None + + formatted_tools = [] + for tool in tools: + # Convert to OpenAI tool format + formatted_tool = { + "type": "function", + "function": { + "name": tool.get("name", ""), + "description": tool.get("description", ""), + "parameters": tool.get("parameters", {}), + }, + } + formatted_tools.append(formatted_tool) + + return formatted_tools + + def _handle_tool_calls( + self, + response, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Any: + """Handle tool calls from OpenAI response. 
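# Sketch of the o1-specific rewrite done in _format_messages above: o1 models
# reject "system" turns, so they are downgraded (to "assistant", as in the
# patch). Standalone illustration with made-up messages.
def downgrade_system_for_o1(messages):
    return [
        {"role": "assistant", "content": m["content"]} if m["role"] == "system" else m
        for m in messages
    ]

downgrade_system_for_o1([
    {"role": "system", "content": "Answer in one sentence."},
    {"role": "user", "content": "What is a vector store?"},
])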
+ + Args: + response: OpenAI API response + available_functions: Dict mapping function names to callables + from_task: Optional task context + from_agent: Optional agent context + + Returns: + Result of function execution or error message + """ + message = response.choices[0].message + + if not message.tool_calls or not available_functions: + return message.content + + # Execute the first tool call + tool_call = message.tool_calls[0] + function_name = tool_call.function.name + function_args = {} + + if function_name not in available_functions: + return f"Error: Function '{function_name}' not found in available functions" + + try: + # Parse function arguments + function_args = json.loads(tool_call.function.arguments) + fn = available_functions[function_name] + + # Execute function with event tracking + assert hasattr(crewai_event_bus, "emit") + started_at = datetime.now() + crewai_event_bus.emit( + self, + event=ToolUsageStartedEvent( + tool_name=function_name, + tool_args=function_args, + ), + ) + + result = fn(**function_args) + + crewai_event_bus.emit( + self, + event=ToolUsageFinishedEvent( + output=result, + tool_name=function_name, + tool_args=function_args, + started_at=started_at, + finished_at=datetime.now(), + ), + ) + + # Emit success event + event_data = { + "response": result, + "call_type": LLMCallType.TOOL_CALL, + "model": self.model, + } + if from_task is not None: + event_data["from_task"] = from_task + if from_agent is not None: + event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**event_data), + ) + + return result + + except json.JSONDecodeError as e: + error_msg = f"Error parsing function arguments: {e}" + crewai_event_bus.emit( + self, + event=ToolUsageErrorEvent( + tool_name=function_name, + tool_args=function_args, + error=error_msg, + ), + ) + return error_msg + except Exception as e: + error_msg = f"Error executing function '{function_name}': {e}" + crewai_event_bus.emit( + self, + event=ToolUsageErrorEvent( + tool_name=function_name, + tool_args=function_args, + error=error_msg, + ), + ) + return error_msg + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + from_task: Optional[Any] = None, + from_agent: Optional[Any] = None, + ) -> Union[str, Any]: + """Call OpenAI API with the given messages. 
+ + Args: + messages: Input messages for the LLM + tools: Optional list of tool schemas + callbacks: Optional callbacks to execute + available_functions: Optional dict of available functions + from_task: Optional task context + from_agent: Optional agent context + + Returns: + LLM response or tool execution result + + Raises: + ValueError: If messages format is invalid + RuntimeError: If API call fails + """ + # Emit call started event + print("calling from native openai", messages) + assert hasattr(crewai_event_bus, "emit") + + # Prepare event data + started_event_data = { + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + "model": self.model, + } + if from_task is not None: + started_event_data["from_task"] = from_task + if from_agent is not None: + started_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallStartedEvent(**started_event_data), + ) + + try: + # Format messages + formatted_messages = self._format_messages(messages) + + # Prepare API call parameters + api_params = { + "model": self.model, + "messages": formatted_messages, + } + + # Add optional parameters + if self.temperature is not None: + api_params["temperature"] = self.temperature + + if self.top_p is not None: + api_params["top_p"] = self.top_p + + if self.n is not None: + api_params["n"] = self.n + + if self.max_tokens is not None: + api_params["max_tokens"] = self.max_tokens + + if self.presence_penalty is not None: + api_params["presence_penalty"] = self.presence_penalty + + if self.frequency_penalty is not None: + api_params["frequency_penalty"] = self.frequency_penalty + + if self.logit_bias is not None: + api_params["logit_bias"] = self.logit_bias + + if self.seed is not None: + api_params["seed"] = self.seed + + if self.logprobs is not None: + api_params["logprobs"] = self.logprobs + + if self.top_logprobs is not None: + api_params["top_logprobs"] = self.top_logprobs + + if self.stop: + api_params["stop"] = self.stop + + if self.response_format is not None: + # Handle structured output for Pydantic models + if hasattr(self.response_format, "model_json_schema"): + api_params["response_format"] = { + "type": "json_schema", + "json_schema": { + "name": self.response_format.__name__, + "schema": self.response_format.model_json_schema(), + "strict": True, + }, + } + else: + api_params["response_format"] = self.response_format + + if self.reasoning_effort is not None and "o1" in self.model: + api_params["reasoning_effort"] = self.reasoning_effort + + # Add tools if provided and supported + formatted_tools = self._format_tools(tools) + if formatted_tools: + api_params["tools"] = formatted_tools + api_params["tool_choice"] = "auto" + + # Execute callbacks before API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_start"): + callback.on_llm_start( + serialized={"name": self.__class__.__name__}, + prompts=[str(formatted_messages)], + ) + + # Make API call + if self.stream: + response = self.client.chat.completions.create( + stream=True, **api_params + ) + # Handle streaming (simplified for now) + full_response = "" + for chunk in response: + if ( + hasattr(chunk.choices[0].delta, "content") + and chunk.choices[0].delta.content + ): + full_response += chunk.choices[0].delta.content + result = full_response + else: + response = self.client.chat.completions.create(**api_params) + # Handle tool calls if present + result = self._handle_tool_calls( + response, available_functions, from_task, 
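# Sketch of the structured-output payload built from a Pydantic response_format
# in the call method above (the model and schema here are illustrative).
from pydantic import BaseModel

class Verdict(BaseModel):
    score: int
    rationale: str

response_format_param = {
    "type": "json_schema",
    "json_schema": {
        "name": Verdict.__name__,
        "schema": Verdict.model_json_schema(),
        "strict": True,
    },
}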
from_agent + ) + + # If no tool calls, return text content + if result == response.choices[0].message.content: + result = response.choices[0].message.content or "" + + # Execute callbacks after API call + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_end"): + callback.on_llm_end(response=result) + + # Emit completion event + completion_event_data = { + "messages": formatted_messages, + "response": result, + "call_type": LLMCallType.LLM_CALL, + "model": self.model, + } + if from_task is not None: + completion_event_data["from_task"] = from_task + if from_agent is not None: + completion_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallCompletedEvent(**completion_event_data), + ) + + return result + + except Exception as e: + # Execute error callbacks + if callbacks: + for callback in callbacks: + if hasattr(callback, "on_llm_error"): + callback.on_llm_error(error=e) + + # Emit failed event + failed_event_data = { + "error": str(e), + } + if from_task is not None: + failed_event_data["from_task"] = from_task + if from_agent is not None: + failed_event_data["from_agent"] = from_agent + + crewai_event_bus.emit( + self, + event=LLMCallFailedEvent(**failed_event_data), + ) + + raise RuntimeError(f"OpenAI API call failed: {str(e)}") from e + + def supports_stop_words(self) -> bool: + """Check if OpenAI models support stop words.""" + return True + + def get_context_window_size(self) -> int: + """Get the context window size for the current model.""" + if self.context_window_size != 0: + return self.context_window_size + + # Use 85% of the context window like the original LLM class + context_window = self.model_config.get("context_window", 4096) + self.context_window_size = int(context_window * 0.85) + return self.context_window_size + + def supports_function_calling(self) -> bool: + """Check if the current model supports function calling.""" + return self.model_config.get("supports_tools", True) diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 3998a9bce..09717b528 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -7,39 +7,71 @@ from crewai.llm import LLM, BaseLLM def create_llm( llm_value: Union[str, LLM, Any, None] = None, + prefer_native: Optional[bool] = None, ) -> Optional[LLM | BaseLLM]: """ Creates or returns an LLM instance based on the given llm_value. + Now supports provider prefixes like 'openai/gpt-4' for native implementations. Args: llm_value (str | BaseLLM | Any | None): - - str: The model name (e.g., "gpt-4"). + - str: The model name (e.g., "gpt-4" or "openai/gpt-4"). - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is. - Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. + prefer_native (bool | None): + - True: Use native provider implementations when available + - False: Always use LiteLLM implementation + - None: Use environment variable CREWAI_PREFER_NATIVE_LLMS (default: True) + - Note: Provider prefixes (openai/, anthropic/) override this setting Returns: A BaseLLM instance if successful, or None if something fails. 
diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py
index 3998a9bce..09717b528 100644
--- a/src/crewai/utilities/llm_utils.py
+++ b/src/crewai/utilities/llm_utils.py
@@ -7,39 +7,71 @@ from crewai.llm import LLM, BaseLLM
 
 def create_llm(
     llm_value: Union[str, LLM, Any, None] = None,
+    prefer_native: Optional[bool] = None,
 ) -> Optional[LLM | BaseLLM]:
     """
     Creates or returns an LLM instance based on the given llm_value.
+    Now supports provider prefixes like 'openai/gpt-4' for native implementations.
 
     Args:
         llm_value (str | BaseLLM | Any | None):
-            - str: The model name (e.g., "gpt-4").
+            - str: The model name (e.g., "gpt-4" or "openai/gpt-4").
             - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is.
             - Any: Attempt to extract known attributes like model_name, temperature, etc.
             - None: Use environment-based or fallback default model.
+        prefer_native (bool | None):
+            - True: Use native provider implementations when available
+            - False: Always use LiteLLM implementation
+            - None: Use environment variable CREWAI_PREFER_NATIVE_LLMS (default: True)
+            - Note: Provider prefixes (openai/, anthropic/) override this setting
 
     Returns:
         A BaseLLM instance if successful, or None if something fails.
+
+    Examples:
+        create_llm("gpt-4")  # Uses LiteLLM or native based on prefer_native
+        create_llm("openai/gpt-4")  # Always uses native OpenAI implementation
+        create_llm("anthropic/claude-3-sonnet")  # Uses native Claude implementation
     """
     # 1) If llm_value is already a BaseLLM or LLM object, return it directly
     if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM):
         return llm_value
 
-    # 2) If llm_value is a string (model name)
+    # 2) Determine if we should prefer native implementations (unless provider prefix is used)
+    if prefer_native is None:
+        prefer_native = os.getenv("CREWAI_PREFER_NATIVE_LLMS", "true").lower() in (
+            "true",
+            "1",
+            "yes",
+        )
+
+    # 3) If llm_value is a string (model name)
    if isinstance(llm_value, str):
         try:
+            # Provider prefix (openai/, anthropic/) always takes precedence
+            if "/" in llm_value:
+                created_llm = LLM(model=llm_value)  # LLM class handles routing
+                return created_llm
+
+            # Try native implementation first if preferred and no prefix
+            if prefer_native:
+                native_llm = _create_native_llm(llm_value)
+                if native_llm:
+                    return native_llm
+
+            # Fall back to LiteLLM
             created_llm = LLM(model=llm_value)
             return created_llm
         except Exception as e:
             print(f"Failed to instantiate LLM with model='{llm_value}': {e}")
             return None
 
-    # 3) If llm_value is None, parse environment variables or use default
+    # 4) If llm_value is None, parse environment variables or use default
     if llm_value is None:
-        return _llm_via_environment_or_fallback()
+        return _llm_via_environment_or_fallback(prefer_native)
 
-    # 4) Otherwise, attempt to extract relevant attributes from an unknown object
+    # 5) Otherwise, attempt to extract relevant attributes from an unknown object
     try:
         # Extract attributes with explicit types
         model = (
@@ -48,6 +80,8 @@ def create_llm(
             or getattr(llm_value, "deployment_name", None)
             or str(llm_value)
         )
+
+        # Extract other parameters
         temperature: Optional[float] = getattr(llm_value, "temperature", None)
         max_tokens: Optional[int] = getattr(llm_value, "max_tokens", None)
         logprobs: Optional[int] = getattr(llm_value, "logprobs", None)
@@ -56,6 +90,7 @@ def create_llm(
         base_url: Optional[str] = getattr(llm_value, "base_url", None)
         api_base: Optional[str] = getattr(llm_value, "api_base", None)
 
+        # Use LLM class constructor which handles routing
         created_llm = LLM(
             model=model,
             temperature=temperature,
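To make the routing rules above concrete, a short sketch follows (model names are illustrative; which branch actually fires also depends on the provider SDKs installed):

    import os

    from crewai.utilities.llm_utils import create_llm

    # An explicit provider prefix always goes through LLM(), which swaps itself for
    # the native class when the SDK import succeeds and otherwise falls back to
    # LiteLLM with the prefix stripped.
    llm_a = create_llm("openai/gpt-4")

    # A bare model name consults prefer_native; the CREWAI_PREFER_NATIVE_LLMS
    # environment variable is only read when the argument is left as None.
    os.environ["CREWAI_PREFER_NATIVE_LLMS"] = "false"
    llm_b = create_llm("gpt-4")                      # routed through LiteLLM
    llm_c = create_llm("gpt-4", prefer_native=True)  # native OpenAILLM when available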
@@ -72,9 +107,94 @@ def create_llm(
         return None
 
 
-def _llm_via_environment_or_fallback() -> Optional[LLM]:
+def _create_native_llm(model: str, **kwargs) -> Optional[BaseLLM]:
+    """
+    Create a native LLM implementation based on the model name.
+
+    Args:
+        model: The model name (e.g., 'gpt-4', 'claude-3-sonnet')
+        **kwargs: Additional parameters for the LLM
+
+    Returns:
+        Native LLM instance if supported, None otherwise
+    """
+    try:
+        # OpenAI models
+        if _is_openai_model(model):
+            from crewai.llms.openai import OpenAILLM
+
+            return OpenAILLM(model=model, **kwargs)
+
+        # Claude models
+        if _is_claude_model(model):
+            from crewai.llms.anthropic import ClaudeLLM
+
+            return ClaudeLLM(model=model, **kwargs)
+
+        # Gemini models
+        if _is_gemini_model(model):
+            from crewai.llms.google import GeminiLLM
+
+            return GeminiLLM(model=model, **kwargs)
+
+        # No native implementation found
+        return None
+
+    except Exception as e:
+        print(f"Failed to create native LLM for model '{model}': {e}")
+        return None
+
+
+def _is_openai_model(model: str) -> bool:
+    """Check if a model is from OpenAI."""
+    openai_prefixes = (
+        "gpt-",
+        "text-davinci",
+        "text-curie",
+        "text-babbage",
+        "text-ada",
+        "davinci",
+        "curie",
+        "babbage",
+        "ada",
+        "o1-",
+        "o3-",
+        "o4-",
+        "chatgpt-",
+    )
+
+    model_lower = model.lower()
+    return any(model_lower.startswith(prefix) for prefix in openai_prefixes)
+
+
+def _is_claude_model(model: str) -> bool:
+    """Check if a model is from Anthropic (Claude)."""
+    claude_prefixes = (
+        "claude-",
+        "claude",  # For cases like just "claude"
+    )
+
+    model_lower = model.lower()
+    return any(model_lower.startswith(prefix) for prefix in claude_prefixes)
+
+
+def _is_gemini_model(model: str) -> bool:
+    """Check if a model is from Google (Gemini)."""
+    gemini_prefixes = (
+        "gemini-",
+        "gemini",  # For cases like just "gemini"
+    )
+
+    model_lower = model.lower()
+    return any(model_lower.startswith(prefix) for prefix in gemini_prefixes)
+
+
+def _llm_via_environment_or_fallback(
+    prefer_native: bool = True,
+) -> Optional[LLM | BaseLLM]:
     """
     Helper function: if llm_value is None, we load environment variables or fallback default model.
+    Now with native provider support.
     """
     model_name = (
         os.environ.get("MODEL")
@@ -83,7 +203,13 @@ def _llm_via_environment_or_fallback() -> Optional[LLM]:
         or DEFAULT_LLM_MODEL
     )
 
-    # Initialize parameters with correct types
+    # Try native implementation first if preferred
+    if prefer_native:
+        native_llm = _create_native_llm(model_name)
+        if native_llm:
+            return native_llm
+
+    # Initialize parameters with correct types (original logic continues)
     model: str = model_name
     temperature: Optional[float] = None
     max_tokens: Optional[int] = None
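A few spot checks of the detection helpers above (they are private functions, so this is illustration rather than public API; the model names are arbitrary examples):

    from crewai.utilities.llm_utils import (
        _is_claude_model,
        _is_gemini_model,
        _is_openai_model,
    )

    assert _is_openai_model("o3-mini")            # matches the "o3-" prefix
    assert _is_claude_model("claude-3-5-haiku")   # matches "claude-"
    assert _is_gemini_model("gemini-1.5-pro")     # matches "gemini-"
    assert not _is_openai_model("mistral-large")  # no native route; LiteLLM handles it

    # Likewise, with MODEL=claude-3-5-sonnet-20241022 in the environment and
    # prefer_native left at its default, _llm_via_environment_or_fallback() hands
    # the name to _create_native_llm() and returns a ClaudeLLM before any LiteLLM
    # parameters are assembled.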