Custom LLM Implementations

CrewAI now supports custom LLM implementations through the BaseLLM abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism.

Using Custom LLM Implementations

To create a custom LLM implementation, you need to:

  1. Inherit from the BaseLLM abstract base class
  2. Implement the required methods:
    • call(): The main method to call the LLM with messages
    • supports_function_calling(): Whether the LLM supports function calling
    • supports_stop_words(): Whether the LLM supports stop words
    • get_context_window_size(): The context window size of the LLM

Example: Basic Custom LLM

from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union

class CustomLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not api_key or not isinstance(api_key, str):
            raise ValueError("Invalid API key: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.api_key = api_key
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with the given messages.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM
        # For example, using requests:
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.api_key}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint, 
                headers=headers, 
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        # Return True if your LLM supports function calling
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        # Return True if your LLM supports stop words
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        # Return the context window size of your LLM
        return 8192

Error Handling Best Practices

When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:

1. Implement Try-Except Blocks for API Calls

Always wrap API calls in try-except blocks to handle different types of errors:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        # API call implementation
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30  # Set a reasonable timeout
        )
        response.raise_for_status()  # Raise an exception for HTTP errors
        return response.json()["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

2. Implement Retry Logic for Transient Failures

For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import time
    
    max_retries = 3
    retry_delay = 1  # seconds
    
    for attempt in range(max_retries):
        try:
            response = requests.post(
                self.endpoint,
                headers=self.headers,
                json=self.prepare_payload(messages),
                timeout=30
            )
            response.raise_for_status()
            return response.json()["choices"][0]["message"]["content"]
        except (requests.Timeout, requests.ConnectionError) as e:
            if attempt < max_retries - 1:
                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
                continue
            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")

3. Validate Input Parameters

Always validate input parameters to prevent runtime errors:

def __init__(self, api_key: str, endpoint: str):
    super().__init__()
    if not api_key or not isinstance(api_key, str):
        raise ValueError("Invalid API key: must be a non-empty string")
    if not endpoint or not isinstance(endpoint, str):
        raise ValueError("Invalid endpoint URL: must be a non-empty string")
    self.api_key = api_key
    self.endpoint = endpoint

4. Handle Authentication Errors Gracefully

Provide clear error messages for authentication failures:

def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    try:
        response = requests.post(
            self.endpoint,
            headers=self.headers,
            json=self.prepare_payload(messages),
            timeout=30
        )
        if response.status_code == 401:
            raise ValueError("Authentication failed: Invalid API key or token")
        elif response.status_code == 403:
            raise ValueError("Authorization failed: Insufficient permissions")
        response.raise_for_status()
        # Process response
    except Exception as e:
        # Handle error
        raise

Example: JWT-based Authentication

For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

from crewai import BaseLLM
from typing import Any, Dict, List, Optional, Union

class JWTAuthLLM(BaseLLM):
    def __init__(self, jwt_token: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token: must be a non-empty string")
        if not endpoint or not isinstance(endpoint, str):
            raise ValueError("Invalid endpoint URL: must be a non-empty string")
        self.jwt_token = jwt_token
        self.endpoint = endpoint
        self.stop = []  # You can customize stop words if needed
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.
        
        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.
            
        Returns:
            Either a text response from the LLM or the result of a tool function call.
            
        Raises:
            TimeoutError: If the LLM request times out.
            RuntimeError: If the LLM request fails for other reasons.
            ValueError: If the response format is invalid.
        """
        # Implement your own logic to call the LLM with JWT authentication
        import requests
        
        try:
            headers = {
                "Authorization": f"Bearer {self.jwt_token}",
                "Content-Type": "application/json"
            }
            
            # Convert string message to proper format if needed
            if isinstance(messages, str):
                messages = [{"role": "user", "content": messages}]
            
            data = {
                "messages": messages,
                "tools": tools
            }
            
            response = requests.post(
                self.endpoint,
                headers=headers,
                json=data,
                timeout=30  # Set a reasonable timeout
            )
            
            if response.status_code == 401:
                raise ValueError("Authentication failed: Invalid JWT token")
            elif response.status_code == 403:
                raise ValueError("Authorization failed: Insufficient permissions")
                
            response.raise_for_status()  # Raise an exception for HTTP errors
            return response.json()["choices"][0]["message"]["content"]
        except requests.Timeout:
            raise TimeoutError("LLM request timed out")
        except requests.RequestException as e:
            raise RuntimeError(f"LLM request failed: {str(e)}")
        except (KeyError, IndexError, ValueError) as e:
            raise ValueError(f"Invalid response format: {str(e)}")
        
    def supports_function_calling(self) -> bool:
        """Check if the LLM supports function calling.
        
        Returns:
            True if the LLM supports function calling, False otherwise.
        """
        return True
        
    def supports_stop_words(self) -> bool:
        """Check if the LLM supports stop words.
        
        Returns:
            True if the LLM supports stop words, False otherwise.
        """
        return True
        
    def get_context_window_size(self) -> int:
        """Get the context window size of the LLM.
        
        Returns:
            The context window size as an integer.
        """
        return 8192

Troubleshooting

Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:

1. Authentication Failures

Symptoms: 401 Unauthorized or 403 Forbidden errors

Solutions:

  • Verify that your API key or JWT token is valid and not expired
  • Check that you're using the correct authentication header format
  • Ensure that your token has the necessary permissions

2. Timeout Issues

Symptoms: Requests taking too long or timing out

Solutions:

  • Implement timeout handling as shown in the examples
  • Use retry logic with exponential backoff
  • Consider using a more reliable network connection

3. Response Parsing Errors

Symptoms: KeyError, IndexError, or ValueError when processing responses

Solutions:

  • Validate the response format before accessing nested fields (see the sketch after this list)
  • Implement proper error handling for malformed responses
  • Check the API documentation for the expected response format
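
For the first point, here is a minimal sketch of a helper that checks the expected fields before reading them. It assumes the OpenAI-style response shape used throughout this page, and the _extract_content name is purely illustrative:

def _extract_content(self, response_data: dict) -> str:
    """Safely pull the message content out of an OpenAI-style response."""
    choices = response_data.get("choices")
    if not isinstance(choices, list) or not choices:
        raise ValueError(f"Invalid response format: missing 'choices': {response_data}")
    message = choices[0].get("message") if isinstance(choices[0], dict) else None
    if not isinstance(message, dict) or "content" not in message:
        raise ValueError(f"Invalid response format: missing 'message.content': {response_data}")
    return message["content"]

Calling self._extract_content(response.json()) instead of indexing directly turns malformed responses into clear ValueError messages.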

4. Rate Limiting

Symptoms: 429 Too Many Requests errors

Solutions:

  • Implement rate limiting in your custom LLM
  • Add exponential backoff for retries
  • Consider using a token bucket algorithm for more precise rate control (a sketch follows this list)
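
The Rate Limiting example under Advanced Features below uses a sliding window. If you want the finer-grained control a token bucket gives, a minimal sketch looks like this (the TokenBucket class and its parameters are illustrative, not part of CrewAI):

import time

class TokenBucket:
    """Allows short bursts while capping the average request rate."""

    def __init__(self, rate_per_second: float, capacity: int):
        self.rate = rate_per_second      # tokens added per second
        self.capacity = capacity         # maximum burst size
        self.tokens = float(capacity)
        self.last_refill = time.time()

    def acquire(self) -> None:
        """Block until a token is available, then consume it."""
        while True:
            now = time.time()
            # Refill proportionally to the elapsed time, capped at capacity
            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
            self.last_refill = now
            if self.tokens >= 1:
                self.tokens -= 1
                return
            time.sleep((1 - self.tokens) / self.rate)

Call acquire() at the start of your call() method, in the same place _enforce_rate_limit() is used in the Rate Limiting example below.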

Advanced Features

Logging

Adding logging to your custom LLM can help with debugging and monitoring:

import logging
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class LoggingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.logger = logging.getLogger("crewai.llm.custom")
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
        try:
            # API call implementation
            response = self._make_api_call(messages, tools)
            self.logger.debug(f"LLM response received: {response[:100]}...")
            return response
        except Exception as e:
            self.logger.error(f"LLM call failed: {str(e)}")
            raise

Rate Limiting

Implementing rate limiting can help avoid overwhelming the LLM API:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class RateLimitedLLM(BaseLLM):
    def __init__(
        self, 
        api_key: str, 
        endpoint: str, 
        requests_per_minute: int = 60
    ):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.requests_per_minute = requests_per_minute
        self.request_times: List[float] = []
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        self._enforce_rate_limit()
        # Record this request time
        self.request_times.append(time.time())
        # Make the actual API call
        return self._make_api_call(messages, tools)
        
    def _enforce_rate_limit(self) -> None:
        """Enforce the rate limit by waiting if necessary."""
        now = time.time()
        # Remove request times older than 1 minute
        self.request_times = [t for t in self.request_times if now - t < 60]
        
        if len(self.request_times) >= self.requests_per_minute:
            # Calculate how long to wait
            oldest_request = min(self.request_times)
            wait_time = 60 - (now - oldest_request)
            if wait_time > 0:
                time.sleep(wait_time)

Metrics Collection

Collecting metrics can help you monitor your LLM usage:

import time
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class MetricsCollectingLLM(BaseLLM):
    def __init__(self, api_key: str, endpoint: str):
        super().__init__()
        self.api_key = api_key
        self.endpoint = endpoint
        self.metrics: Dict[str, Any] = {
            "total_calls": 0,
            "total_tokens": 0,
            "errors": 0,
            "latency": []
        }
        
    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        start_time = time.time()
        self.metrics["total_calls"] += 1
        
        try:
            response = self._make_api_call(messages, tools)
            # Estimate tokens (simplified)
            if isinstance(messages, str):
                token_estimate = len(messages) // 4
            else:
                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
            self.metrics["total_tokens"] += token_estimate
            return response
        except Exception as e:
            self.metrics["errors"] += 1
            raise
        finally:
            latency = time.time() - start_time
            self.metrics["latency"].append(latency)
            
    def get_metrics(self) -> Dict[str, Any]:
        """Return the collected metrics."""
        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
        return {
            **self.metrics,
            "avg_latency": avg_latency
        }

Advanced Usage: Function Calling

If your LLM supports function calling, you can implement the function calling logic in your custom LLM:

import json
from typing import Any, Dict, List, Optional, Union

# The call() method below belongs inside your custom LLM class; it reuses the
# jwt_token and endpoint attributes from the JWTAuthLLM example above.
def call(
    self,
    messages: Union[str, List[Dict[str, str]]],
    tools: Optional[List[dict]] = None,
    callbacks: Optional[List[Any]] = None,
    available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
    import requests
    
    try:
        headers = {
            "Authorization": f"Bearer {self.jwt_token}",
            "Content-Type": "application/json"
        }
        
        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        
        data = {
            "messages": messages,
            "tools": tools
        }
        
        response = requests.post(
            self.endpoint,
            headers=headers,
            json=data,
            timeout=30
        )
        response.raise_for_status()
        response_data = response.json()
        
        # Check if the LLM wants to call a function
        if response_data["choices"][0]["message"].get("tool_calls"):
            tool_calls = response_data["choices"][0]["message"]["tool_calls"]
            
            # Process each tool call
            for tool_call in tool_calls:
                function_name = tool_call["function"]["name"]
                function_args = json.loads(tool_call["function"]["arguments"])
                
                if available_functions and function_name in available_functions:
                    function_to_call = available_functions[function_name]
                    function_response = function_to_call(**function_args)
                    
                    # Add the function response to the messages
                    messages.append({
                        "role": "tool",
                        "tool_call_id": tool_call["id"],
                        "name": function_name,
                        "content": str(function_response)
                    })
            
            # Call the LLM again with the updated messages
            return self.call(messages, tools, callbacks, available_functions)
        
        # Return the text response if no function call
        return response_data["choices"][0]["message"]["content"]
    except requests.Timeout:
        raise TimeoutError("LLM request timed out")
    except requests.RequestException as e:
        raise RuntimeError(f"LLM request failed: {str(e)}")
    except (KeyError, IndexError, ValueError) as e:
        raise ValueError(f"Invalid response format: {str(e)}")

Using Your Custom LLM with CrewAI

Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

from crewai import Agent, Task, Crew

# Create your custom LLM instance
jwt_llm = JWTAuthLLM(
    jwt_token="your.jwt.token", 
    endpoint="https://your-llm-endpoint.com/v1/chat/completions"
)

# Use it with an agent
agent = Agent(
    role="Research Assistant",
    goal="Find information on a topic",
    backstory="You are a research assistant tasked with finding information.",
    llm=jwt_llm,
)

# Create a task for the agent
task = Task(
    description="Research the benefits of exercise",
    agent=agent,
    expected_output="A summary of the benefits of exercise",
)

# Execute the task
result = agent.execute_task(task)
print(result)

# Or use it with a crew
crew = Crew(
    agents=[agent],
    tasks=[task],
    manager_llm=jwt_llm,  # Use your custom LLM for the manager
)

# Run the crew
result = crew.kickoff()
print(result)

Implementing Your Own Authentication Mechanism

The BaseLLM class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use:

  • OAuth tokens
  • Client certificates
  • Custom headers
  • Session-based authentication
  • Any other authentication method required by your LLM provider

Simply implement the appropriate authentication logic in your custom LLM class.
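
For example, here is a minimal sketch of an implementation that authenticates with custom headers. The header names (X-Client-Id, X-Client-Secret) and the endpoint are placeholders for whatever your provider actually requires:

import requests
from typing import Any, Dict, List, Optional, Union

from crewai import BaseLLM

class CustomHeaderLLM(BaseLLM):
    def __init__(self, client_id: str, client_secret: str, endpoint: str):
        super().__init__()  # Initialize the base class to set default attributes
        self.client_id = client_id
        self.client_secret = client_secret
        self.endpoint = endpoint

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Replace these header names with whatever your provider expects
        headers = {
            "X-Client-Id": self.client_id,
            "X-Client-Secret": self.client_secret,
            "Content-Type": "application/json",
        }
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]
        response = requests.post(
            self.endpoint,
            headers=headers,
            json={"messages": messages},
            timeout=30,
        )
        response.raise_for_status()
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False  # This sketch does not forward tool schemas

    def supports_stop_words(self) -> bool:
        return True

    def get_context_window_size(self) -> int:
        return 8192

The same pattern works for OAuth tokens, client certificates (via the requests cert parameter), or session-based authentication: put the credential handling in __init__ and apply it inside call().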