Refactor LLM module by extracting BaseLLM to a separate file

This commit moves the BaseLLM abstract base class from llm.py to a new file llms/base_llm.py to improve code organization. The changes include:

- Creating a new file src/crewai/llms/base_llm.py
- Moving the BaseLLM class to the new file
- Updating imports in __init__.py and llm.py to reflect the new location
- Updating test cases to use the new import path

The refactoring maintains the existing functionality while improving the project's module structure.
This commit is contained in:
Lorenze Jay
2025-03-04 15:54:46 -08:00
parent 963ed23b63
commit 709941c4c7
6 changed files with 680 additions and 203 deletions

View File

@@ -4,7 +4,8 @@ from crewai.agent import Agent
from crewai.crew import Crew
from crewai.flow.flow import Flow
from crewai.knowledge.knowledge import Knowledge
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.process import Process
from crewai.task import Task

View File

@@ -4,7 +4,6 @@ import os
import sys
import threading
import warnings
from abc import ABC, abstractmethod
from contextlib import contextmanager
from typing import Any, Dict, List, Literal, Optional, Type, Union, cast
@@ -27,6 +26,7 @@ with warnings.catch_warnings():
from litellm.utils import get_supported_openai_params, supports_response_schema
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.events import crewai_event_bus
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
@@ -35,108 +35,6 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
load_dotenv()
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
rely on litellm's authentication mechanism.
Custom LLM implementations should handle error cases gracefully, including
timeouts, authentication failures, and malformed responses. They should also
implement proper validation for input parameters and provide clear error
messages when things go wrong.
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
def __init__(self):
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
All custom LLM implementations should call super().__init__() to ensure
that these default attributes are properly initialized.
"""
self.stop = []
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
This method should return True if the LLM implementation supports
function calling (tools), and False otherwise. If this method returns
True, the LLM should be able to handle the 'tools' parameter in the
call() method.
Returns:
True if the LLM supports function calling, False otherwise.
"""
pass
@abstractmethod
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
This method should return True if the LLM implementation supports
stop words, and False otherwise. If this method returns True, the
LLM should respect the 'stop' attribute when generating responses.
Returns:
True if the LLM supports stop words, False otherwise.
"""
pass
@abstractmethod
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
This method should return the maximum number of tokens that the LLM
can process in a single request. This is used by CrewAI to ensure
that messages don't exceed the LLM's context window.
Returns:
The context window size as an integer.
"""
pass
class FilteredStream:
def __init__(self, original_stream):
self._original_stream = original_stream

104
src/crewai/llms/base_llm.py Normal file
View File

@@ -0,0 +1,104 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
class BaseLLM(ABC):
"""Abstract base class for LLM implementations.
This class defines the interface that all LLM implementations must follow.
Users can extend this class to create custom LLM implementations that don't
rely on litellm's authentication mechanism.
Custom LLM implementations should handle error cases gracefully, including
timeouts, authentication failures, and malformed responses. They should also
implement proper validation for input parameters and provide clear error
messages when things go wrong.
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
"""
def __init__(self):
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
All custom LLM implementations should call super().__init__() to ensure
that these default attributes are properly initialized.
"""
self.stop = []
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Call the LLM with the given messages.
Args:
messages: Input messages for the LLM.
Can be a string or list of message dictionaries.
If string, it will be converted to a single user message.
If list, each dict must have 'role' and 'content' keys.
tools: Optional list of tool schemas for function calling.
Each tool should define its name, description, and parameters.
callbacks: Optional list of callback functions to be executed
during and after the LLM call.
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
Returns:
Either a text response from the LLM (str) or
the result of a tool function call (Any).
Raises:
ValueError: If the messages format is invalid.
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
@abstractmethod
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
This method should return True if the LLM implementation supports
function calling (tools), and False otherwise. If this method returns
True, the LLM should be able to handle the 'tools' parameter in the
call() method.
Returns:
True if the LLM supports function calling, False otherwise.
"""
pass
@abstractmethod
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
This method should return True if the LLM implementation supports
stop words, and False otherwise. If this method returns True, the
LLM should respect the 'stop' attribute when generating responses.
Returns:
True if the LLM supports stop words, False otherwise.
"""
pass
@abstractmethod
def get_context_window_size(self) -> int:
"""Get the context window size of the LLM.
This method should return the maximum number of tokens that the LLM
can process in a single request. This is used by CrewAI to ensure
that messages don't exceed the LLM's context window.
Returns:
The context window size as an integer.
"""
pass

View File

@@ -0,0 +1,107 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": "What is the answer to life, the universe, and everything?"}],
"model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '206'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"id\": \"chatcmpl-B7W6FS0wpfndLdg12G3H6ZAXcYhJi\",\n \"object\":
\"chat.completion\",\n \"created\": 1741131387,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n
\ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\":
\"assistant\",\n \"content\": \"The answer to life, the universe, and
everything, famously found in Douglas Adams' \\\"The Hitchhiker's Guide to the
Galaxy,\\\" is the number 42. However, the question itself is left ambiguous,
leading to much speculation and humor in the story.\",\n \"refusal\":
null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n
\ }\n ],\n \"usage\": {\n \"prompt_tokens\": 30,\n \"completion_tokens\":
54,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\":
0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n
\ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\":
0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\":
\"default\",\n \"system_fingerprint\": \"fp_06737a9306\"\n}\n"
headers:
CF-RAY:
- 91b532234c18cf1f-SJC
Connection:
- keep-alive
Content-Encoding:
- gzip
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:36:28 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=DgLb6UAE6W4Oeto1Bi2RiKXQVV5TTzkXdXWFdmAEwQQ-1741131388-1.0.1.1-jWQtsT95wOeQbmIxAK7cv8gJWxYi1tQ.IupuJzBDnZr7iEChwVUQBRfnYUBJPDsNly3bakCDArjD_S.FLKwH6xUfvlxgfd4YSBhBPy7bcgw;
path=/; expires=Wed, 05-Mar-25 00:06:28 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=Oa59XCmqjKLKwU34la1hkTunN57JW20E.ZHojvRBfow-1741131388236-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
Transfer-Encoding:
- chunked
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '776'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999960'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_97824e8fe7c1aca3fbcba7c925388b39
http_version: HTTP/1.1
status_code: 200
version: 1

View File

@@ -0,0 +1,305 @@
interactions:
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54660799a15b4-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:16 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=OwS.6cyfDpbxxx8vPp4THv5eNoDMQK0qSVN.wSUyOYk-1741132216-1.0.1.1-QBVd08CjfmDBpNnYQM5ILGbTUWKh6SDM9E4ARG4SV2Z9Q4ltFSFLXoo38OGJApUNZmzn4PtRsyAPsHt_dsrHPF6MD17FPcGtrnAHqCjJrfU;
path=/; expires=Wed, 05-Mar-25 00:20:16 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=n_ebDsAOhJm5Mc7OMx8JDiOaZq5qzHCnVxyS3KN0BwA-1741132216951-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '19'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_042a4e8f9432f6fde7a02037bb6caafa
http_version: HTTP/1.1
status_code: 400
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54664bb1acef1-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:17 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=.wGU4pJEajaSzFWjp05TBQwWbCNA2CgpYNu7UYOzbbM-1741132217-1.0.1.1-NoLiAx4qkplllldYYxZCOSQGsX6hsPUJIEyqmt84B3g7hjW1s7.jk9C9PYzXagHWjT0sQ9Ny4LZBA94lDJTfDBZpty8NJQha7ZKW0P_msH8;
path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=GAjgJjVLtN49bMeWdWZDYLLkEkK51z5kxK4nKqhAzxY-1741132217161-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '25'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_7a1d027da1ef4468e861e570c72e98fb
http_version: HTTP/1.1
status_code: 400
- request:
body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."},
{"role": "user", "content": [{"role": "system", "content": "You are Say Hi.
You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give
my best complete final answer to the task respond using the exact following
format:\n\nThought: I now can give a great answer\nFinal Answer: Your final
answer must be the great and the most complete as possible, it must be outcome
described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user",
"content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria
for your final answer: A greeting to the user\nyou MUST return the actual complete
content as the final answer, not a summary.\n\nBegin! This is VERY important
to you, use the tools available and give your best Final Answer, your job depends
on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}'
headers:
accept:
- application/json
accept-encoding:
- gzip, deflate
connection:
- keep-alive
content-length:
- '931'
content-type:
- application/json
host:
- api.openai.com
user-agent:
- OpenAI/Python 1.61.0
x-stainless-arch:
- arm64
x-stainless-async:
- 'false'
x-stainless-lang:
- python
x-stainless-os:
- MacOS
x-stainless-package-version:
- 1.61.0
x-stainless-retry-count:
- '0'
x-stainless-runtime:
- CPython
x-stainless-runtime-version:
- 3.12.8
method: POST
uri: https://api.openai.com/v1/chat/completions
response:
content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n
\ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n
\ \"code\": \"missing_required_parameter\"\n }\n}"
headers:
CF-RAY:
- 91b54666183beb22-SJC
Connection:
- keep-alive
Content-Length:
- '219'
Content-Type:
- application/json
Date:
- Tue, 04 Mar 2025 23:50:17 GMT
Server:
- cloudflare
Set-Cookie:
- __cf_bm=VwjWHHpkZMJlosI9RbMqxYDBS1t0JK4tWpAy4lST2QM-1741132217-1.0.1.1-u7PU.ZvVBTXNB5R8vaYfWdPXAjWZ3ZcTAy656VaGDZmKIckk5od._eQdn0W0EGVtEMm3TuF60z4GZAPDwMYvb3_3cw1RuEMmQbp4IIrl7VY;
path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly;
Secure; SameSite=None
- _cfuvid=NglAAsQBoiabMuuHFgilRjflSPFqS38VGKnGyweuCuw-1741132217438-0.0.1.1-604800000;
path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None
X-Content-Type-Options:
- nosniff
access-control-expose-headers:
- X-Request-ID
alt-svc:
- h3=":443"; ma=86400
cf-cache-status:
- DYNAMIC
openai-organization:
- crewai-iuxna1
openai-processing-ms:
- '56'
openai-version:
- '2020-10-01'
strict-transport-security:
- max-age=31536000; includeSubDomains; preload
x-ratelimit-limit-requests:
- '30000'
x-ratelimit-limit-tokens:
- '150000000'
x-ratelimit-remaining-requests:
- '29999'
x-ratelimit-remaining-tokens:
- '149999974'
x-ratelimit-reset-requests:
- 2ms
x-ratelimit-reset-tokens:
- 0s
x-request-id:
- req_3c335b308b82cc2214783a4bf2fc0fd4
http_version: HTTP/1.1
status_code: 400
version: 1

View File

@@ -1,101 +1,153 @@
from typing import Any, Dict, List, Optional, Union
from unittest.mock import Mock
import pytest
from crewai.llm import BaseLLM
from crewai import Agent, Crew, Process, Task
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.llm_utils import create_llm
class CustomLLM(BaseLLM):
"""Custom LLM implementation for testing.
This is a simple implementation of the BaseLLM abstract base class
that returns a predefined response for testing purposes.
"""
def __init__(self, response: str = "Custom LLM response"):
def __init__(self, response="Default response"):
"""Initialize the CustomLLM with a predefined response.
Args:
response: The predefined response to return from call().
"""
super().__init__()
self.response = response
self.calls = []
self.stop = []
self.call_count = 0
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return the predefined response.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
The predefined response string.
messages,
tools=None,
callbacks=None,
available_functions=None,
):
"""
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
Mock LLM call that returns a predefined response.
Properly formats messages to match OpenAI's expected structure.
"""
self.call_count += 1
# If input is a string, convert to proper message format
if isinstance(messages, str):
messages = [{"role": "user", "content": messages}]
# Ensure each message has properly formatted content
for message in messages:
if isinstance(message["content"], str):
message["content"] = [{"type": "text", "text": message["content"]}]
# Return predefined response in expected format
if "Thought:" in str(messages):
return f"Thought: I will say hi\nFinal Answer: {self.response}"
return self.response
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
"""Return False to indicate that function calling is not supported.
Returns:
True, indicating that this LLM supports function calling.
False, indicating that this LLM does not support function calling.
"""
return True
return False
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
"""Return False to indicate that stop words are not supported.
Returns:
True, indicating that this LLM supports stop words.
False, indicating that this LLM does not support stop words.
"""
return True
return False
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
4096, a typical context window size for modern LLMs.
"""
return 8192
return 4096
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_implementation():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="The answer is 42")
# Test that create_llm returns the custom LLM instance directly
result_llm = create_llm(custom_llm)
assert result_llm is custom_llm
# Test calling the custom LLM
response = result_llm.call("What is the answer to life, the universe, and everything?")
# Verify that the custom LLM was called
assert len(custom_llm.calls) > 0
response = result_llm.call(
"What is the answer to life, the universe, and everything?"
)
# Verify that the response from the custom LLM was used
assert response == "The answer is 42"
assert "42" in response
@pytest.mark.vcr(filter_headers=["authorization"])
def test_custom_llm_within_crew():
"""Test that a custom LLM implementation works with create_llm."""
custom_llm = CustomLLM(response="Hello! Nice to meet you!")
agent = Agent(
role="Say Hi",
goal="Say hi to the user",
backstory="""You just say hi to the user""",
llm=custom_llm,
)
task = Task(
description="Say hi to the user",
expected_output="A greeting to the user",
agent=agent,
)
crew = Crew(
agents=[agent],
tasks=[task],
process=Process.sequential,
)
result = crew.kickoff()
# Assert the LLM was called
assert custom_llm.call_count > 0
# Assert we got a response
assert "Hello!" in result.raw
def test_custom_llm_message_formatting():
"""Test that the custom LLM properly formats messages"""
custom_llm = CustomLLM(response="Test response")
# Test with string input
result = custom_llm.call("Test message")
assert result == "Test response"
# Test with message list
messages = [
{"role": "system", "content": "System message"},
{"role": "user", "content": "User message"},
]
result = custom_llm.call(messages)
assert result == "Test response"
class JWTAuthLLM(BaseLLM):
"""Custom LLM implementation with JWT authentication."""
def __init__(self, jwt_token: str):
super().__init__()
if not jwt_token or not isinstance(jwt_token, str):
@@ -103,7 +155,7 @@ class JWTAuthLLM(BaseLLM):
self.jwt_token = jwt_token
self.calls = []
self.stop = []
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -112,24 +164,26 @@ class JWTAuthLLM(BaseLLM):
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Record the call and return a predefined response."""
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# In a real implementation, this would use the JWT token to authenticate
# with an external service
return "Response from JWT-authenticated LLM"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported."""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported."""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size."""
return 8192
@@ -138,15 +192,15 @@ class JWTAuthLLM(BaseLLM):
def test_custom_llm_with_jwt_auth():
"""Test a custom LLM implementation with JWT authentication."""
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
# Test that create_llm returns the JWT-authenticated LLM instance directly
result_llm = create_llm(jwt_llm)
assert result_llm is jwt_llm
# Test calling the JWT-authenticated LLM
response = result_llm.call("Test message")
# Verify that the JWT-authenticated LLM was called
assert len(jwt_llm.calls) > 0
# Verify that the response from the JWT-authenticated LLM was used
@@ -158,7 +212,7 @@ def test_jwt_auth_llm_validation():
# Test with invalid JWT token (empty string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token="")
# Test with invalid JWT token (non-string)
with pytest.raises(ValueError, match="Invalid JWT token"):
JWTAuthLLM(jwt_token=None)
@@ -166,10 +220,10 @@ def test_jwt_auth_llm_validation():
class TimeoutHandlingLLM(BaseLLM):
"""Custom LLM implementation with timeout handling and retry logic."""
def __init__(self, max_retries: int = 3, timeout: int = 30):
"""Initialize the TimeoutHandlingLLM with retry and timeout settings.
Args:
max_retries: Maximum number of retry attempts.
timeout: Timeout in seconds for each API call.
@@ -180,7 +234,7 @@ class TimeoutHandlingLLM(BaseLLM):
self.calls = []
self.stop = []
self.fail_count = 0 # Number of times to simulate failure
def call(
self,
messages: Union[str, List[Dict[str, str]]],
@@ -189,28 +243,30 @@ class TimeoutHandlingLLM(BaseLLM):
available_functions: Optional[Dict[str, Any]] = None,
) -> Union[str, Any]:
"""Simulate API calls with timeout handling and retry logic.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
Returns:
A response string based on whether this is the first attempt or a retry.
Raises:
TimeoutError: If all retry attempts fail.
"""
# Record the initial call
self.calls.append({
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0
})
self.calls.append(
{
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
"attempt": 0,
}
)
# Simulate retry logic
for attempt in range(self.max_retries):
# Skip the first attempt recording since we already did that above
@@ -220,7 +276,9 @@ class TimeoutHandlingLLM(BaseLLM):
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
@@ -229,45 +287,49 @@ class TimeoutHandlingLLM(BaseLLM):
else:
# This is a retry attempt (attempt > 0)
# Always record retry attempts
self.calls.append({
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions
})
self.calls.append(
{
"retry_attempt": attempt,
"messages": messages,
"tools": tools,
"callbacks": callbacks,
"available_functions": available_functions,
}
)
# Simulate a failure if fail_count > 0
if self.fail_count > 0:
self.fail_count -= 1
# If we've used all retries, raise an error
if attempt == self.max_retries - 1:
raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
raise TimeoutError(
f"LLM request failed after {self.max_retries} attempts"
)
# Otherwise, continue to the next attempt (simulating backoff)
continue
else:
# Success on retry
return "Response after retry"
def supports_function_calling(self) -> bool:
"""Return True to indicate that function calling is supported.
Returns:
True, indicating that this LLM supports function calling.
"""
return True
def supports_stop_words(self) -> bool:
"""Return True to indicate that stop words are supported.
Returns:
True, indicating that this LLM supports stop words.
"""
return True
def get_context_window_size(self) -> int:
"""Return a default context window size.
Returns:
8192, a typical context window size for modern LLMs.
"""
@@ -281,14 +343,14 @@ def test_timeout_handling_llm():
response = llm.call("Test message")
assert response == "First attempt response"
assert len(llm.calls) == 1
# Test successful retry
llm = TimeoutHandlingLLM()
llm.fail_count = 1 # Fail once, then succeed
response = llm.call("Test message")
assert response == "Response after retry"
assert len(llm.calls) == 2 # Initial call + successful retry call
# Test failure after all retries
llm = TimeoutHandlingLLM(max_retries=2)
llm.fail_count = 2 # Fail twice, which is all retries