mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
Add support for custom LLM implementations
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -4,6 +4,7 @@ import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
|
||||
|
||||
@@ -34,6 +35,78 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
"""Abstract base class for LLM implementations.
|
||||
|
||||
This class defines the interface that all LLM implementations must follow.
|
||||
Users can extend this class to create custom LLM implementations that don't
|
||||
rely on litellm's authentication mechanism.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
This constructor sets default values for attributes that are expected
|
||||
by the CrewAgentExecutor and other components.
|
||||
"""
|
||||
self.stop = []
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> Union[str, Any]:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
If string, it will be converted to a single user message.
|
||||
If list, each dict must have 'role' and 'content' keys.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
Each tool should define its name, description, and parameters.
|
||||
callbacks: Optional list of callback functions to be executed
|
||||
during and after the LLM call.
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
the result of a tool function call (Any).
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
"""
|
||||
pass
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
self._original_stream = original_stream
|
||||
@@ -126,7 +199,7 @@ def suppress_warnings():
|
||||
sys.stderr = old_stderr
|
||||
|
||||
|
||||
class LLM:
|
||||
class LLM(BaseLLM):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
|
||||
Reference in New Issue
Block a user