Files
crewAI/src/crewai/utilities/llm_utils.py
devin-ai-integration[bot] 807c13e144 Add support for custom LLM implementations (#2277)
* Add support for custom LLM implementations

Co-Authored-By: Joe Moura <joao@crewai.com>

* Fix import sorting and type annotations

Co-Authored-By: Joe Moura <joao@crewai.com>

* Fix linting issues with import sorting

Co-Authored-By: Joe Moura <joao@crewai.com>

* Fix type errors in crew.py by updating tool-related methods to return List[BaseTool]

Co-Authored-By: Joe Moura <joao@crewai.com>

* Enhance custom LLM implementation with better error handling, documentation, and test coverage

Co-Authored-By: Joe Moura <joao@crewai.com>

* Refactor LLM module by extracting BaseLLM to a separate file

This commit moves the BaseLLM abstract base class from llm.py to a new file llms/base_llm.py to improve code organization. The changes include:

- Creating a new file src/crewai/llms/base_llm.py
- Moving the BaseLLM class to the new file
- Updating imports in __init__.py and llm.py to reflect the new location
- Updating test cases to use the new import path

The refactoring maintains the existing functionality while improving the project's module structure.

* Add AISuite LLM support and update dependencies

- Integrate AISuite as a new third-party LLM option
- Update pyproject.toml and uv.lock to include aisuite package
- Modify BaseLLM to support more flexible initialization
- Remove unnecessary LLM imports across multiple files
- Implement AISuiteLLM with basic chat completion functionality

* Update AISuiteLLM and LLM utility type handling

- Modify AISuiteLLM to support more flexible input types for messages
- Update type hints in AISuiteLLM to allow string or list of message dictionaries
- Enhance LLM utility function to support broader LLM type annotations
- Remove default `self.stop` attribute from BaseLLM initialization

* Update LLM imports and type hints across multiple files

- Modify imports in crew_chat.py to use LLM instead of BaseLLM
- Update type hints in llm_utils.py to use LLM type
- Add optional `stop` parameter to BaseLLM initialization
- Refactor type handling for LLM creation and usage

* Improve stop words handling in CrewAgentExecutor

- Add support for handling existing stop words in LLM configuration
- Ensure stop words are correctly merged and deduplicated
- Update type hints to support both LLM and BaseLLM types

* Remove abstract method set_callbacks from BaseLLM class

* Enhance CustomLLM and JWTAuthLLM initialization with model parameter

- Update CustomLLM to accept a model parameter during initialization
- Modify test cases to include the new model argument
- Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors
- Update type hints in create_llm function to support both LLM and BaseLLM types

* Enhance create_llm function to support BaseLLM type

- Update the create_llm function to accept both LLM and BaseLLM instances
- Ensure compatibility with existing LLM handling logic

* Update type hint for initialize_chat_llm to support BaseLLM

- Modify the return type of initialize_chat_llm function to allow for both LLM and BaseLLM instances
- Ensure compatibility with recent changes in create_llm function

* Refactor AISuiteLLM to include tools parameter in completion methods

- Update the _prepare_completion_params method to accept an optional tools parameter
- Modify the chat completion method to utilize the new tools parameter for enhanced functionality
- Clean up print statements for better code clarity

* Remove unused tool_calls handling in AISuiteLLM chat completion method for cleaner code.

* Refactor Crew class and LLM hierarchy for improved type handling and code clarity

- Update Crew class methods to enhance readability with consistent formatting and type hints.
- Change LLM class to inherit from BaseLLM for better structure.
- Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor.
- Adjust BaseLLM to provide default implementations for stop words and context window size methods.
- Clean up AISuiteLLM by removing unused methods related to stop words and context window size.

* Remove unused `stream` method from `BaseLLM` class to enhance code clarity and maintainability.

---------

Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com>
Co-authored-by: Joe Moura <joao@crewai.com>
Co-authored-by: Lorenze Jay <lorenzejaytech@gmail.com>
Co-authored-by: João Moura <joaomdmoura@gmail.com>
Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com>
2025-03-25 12:39:08 -04:00

197 lines
6.8 KiB
Python

import os
from typing import Any, Dict, List, Optional, Union
from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS
from crewai.llm import LLM, BaseLLM
def create_llm(
    llm_value: Union[str, LLM, Any, None] = None,
) -> Optional[LLM | BaseLLM]:
    """
    Resolve ``llm_value`` into an LLM instance.

    Args:
        llm_value (str | BaseLLM | Any | None):
            - BaseLLM (including LLM): handed back unchanged.
            - None: resolved through _llm_via_environment_or_fallback().
            - str: treated as the model name for a new LLM.
            - Any other object: its ``model`` / ``model_name`` /
              ``deployment_name`` attribute (or ``str()`` of the object when
              none of those is truthy) becomes the model, and a fixed set of
              optional attributes is forwarded to LLM.

    Returns:
        The resolved instance, or None when constructing the LLM raised.
        Construction errors are printed, not propagated.
    """
    # Already an LLM-like object: nothing to build.
    if isinstance(llm_value, (LLM, BaseLLM)):
        return llm_value

    # Nothing supplied: defer to the environment / default model.
    if llm_value is None:
        return _llm_via_environment_or_fallback()

    # A plain string is a model name.
    if isinstance(llm_value, str):
        try:
            return LLM(model=llm_value)
        except Exception as exc:
            print(f"Failed to instantiate LLM with model='{llm_value}': {exc}")
            return None

    # Unknown object: probe it for the attributes LLM understands.
    try:
        # First truthy candidate wins; str() of the object is the last resort.
        for candidate in ("model", "model_name", "deployment_name"):
            model = getattr(llm_value, candidate, None)
            if model:
                break
        else:
            model = str(llm_value)

        # Missing attributes are forwarded as None.
        forwarded = {
            attr: getattr(llm_value, attr, None)
            for attr in (
                "temperature",
                "max_tokens",
                "logprobs",
                "timeout",
                "api_key",
                "base_url",
                "api_base",
            )
        }
        return LLM(model=model, **forwarded)
    except Exception as exc:
        print(f"Error instantiating LLM from unknown object type: {exc}")
        return None
def _llm_via_environment_or_fallback() -> Optional[LLM]:
    """
    Build an LLM from environment variables, falling back to DEFAULT_LLM_MODEL.

    Called by create_llm() when no llm_value is supplied.

    Resolution steps:
        1. Model name: first non-empty of the MODEL, MODEL_NAME and
           OPENAI_MODEL_NAME environment variables, else DEFAULT_LLM_MODEL.
        2. ``base_url`` from BASE_URL / OPENAI_API_BASE / OPENAI_BASE_URL and
           ``api_base`` from API_BASE / AZURE_API_BASE; when only one of the
           two is populated it is copied to the other.
        3. Provider-specific entries from ENV_VARS (the provider is the text
           before the first "/" in the model name, or "openai" when there is
           no "/") are layered on top.
        4. Parameters whose value is None are dropped before calling LLM.

    Returns:
        The constructed LLM, or None if LLM(...) raised (the error is printed).
    """
    model_name = (
        os.environ.get("MODEL")
        or os.environ.get("MODEL_NAME")
        or os.environ.get("OPENAI_MODEL_NAME")
        or DEFAULT_LLM_MODEL
    )

    # Optional base URL / API base from the environment.
    base_url: Optional[str] = (
        os.environ.get("BASE_URL")
        or os.environ.get("OPENAI_API_BASE")
        or os.environ.get("OPENAI_BASE_URL")
    )
    api_base: Optional[str] = os.environ.get("API_BASE") or os.environ.get(
        "AZURE_API_BASE"
    )

    # Synchronize base_url and api_base if one is populated and the other is not.
    if base_url and not api_base:
        api_base = base_url
    elif api_base and not base_url:
        base_url = api_base

    # Only parameters that can carry a value at this point are listed; every
    # other LLM parameter is left to LLM's own defaults unless a provider
    # entry below supplies it.
    llm_params: Dict[str, Any] = {
        "model": model_name,
        "base_url": base_url,
        "api_base": api_base,
        "callbacks": [],
    }

    # Environment variable names that are never forwarded as LLM parameters.
    # NOTE(review): presumably the provider SDK reads these directly from the
    # environment -- confirm.
    UNACCEPTED_ATTRIBUTES = [
        "AWS_ACCESS_KEY_ID",
        "AWS_SECRET_ACCESS_KEY",
        "AWS_REGION_NAME",
    ]

    set_provider = model_name.split("/")[0] if "/" in model_name else "openai"

    if set_provider in ENV_VARS:
        env_vars_for_provider = ENV_VARS[set_provider]
        if isinstance(env_vars_for_provider, (list, tuple)):
            for env_var in env_vars_for_provider:
                # Validate the entry before touching it: calling .get() on a
                # non-dict would raise AttributeError outside any try block
                # and the diagnostic below would never be reached.
                if not isinstance(env_var, dict):
                    print(
                        f"Expected env_var to be a dictionary, but got {type(env_var)}"
                    )
                    continue

                key_name = env_var.get("key_name")
                if key_name and key_name not in UNACCEPTED_ATTRIBUTES:
                    env_value = os.environ.get(key_name)
                    if env_value:
                        # Map environment variable names to recognized parameters
                        param_key = _normalize_key_name(key_name.lower())
                        llm_params[param_key] = env_value
                elif env_var.get("default", False):
                    # A "default" entry carries literal parameter values;
                    # everything except its bookkeeping keys is copied over.
                    for key, value in env_var.items():
                        if key not in ["prompt", "key_name", "default"]:
                            llm_params[key.lower()] = value

    # Remove None values
    llm_params = {k: v for k, v in llm_params.items() if v is not None}

    # Try creating the LLM
    try:
        return LLM(**llm_params)
    except Exception as e:
        print(
            f"Error instantiating LLM from environment/fallback: {type(e).__name__}: {e}"
        )
        return None
def _normalize_key_name(key_name: str) -> str:
    """
    Translate an environment-variable-style name into a parameter key.

    Returns the first entry of LITELLM_PARAMS that occurs as a substring of
    ``key_name``; when no entry matches, ``key_name`` is returned unchanged.
    """
    matches = (param for param in LITELLM_PARAMS if param in key_name)
    return next(matches, key_name)