diff --git a/docs/custom_llm.md b/docs/custom_llm.md
index 245aa3ca4..3d0fdc0c4 100644
--- a/docs/custom_llm.md
+++ b/docs/custom_llm.md
@@ -22,6 +22,10 @@ from typing import Any, Dict, List, Optional, Union
 class CustomLLM(BaseLLM):
     def __init__(self, api_key: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
+        if not api_key or not isinstance(api_key, str):
+            raise ValueError("Invalid API key: must be a non-empty string")
+        if not endpoint or not isinstance(endpoint, str):
+            raise ValueError("Invalid endpoint URL: must be a non-empty string")
         self.api_key = api_key
         self.endpoint = endpoint
         self.stop = []  # You can customize stop words if needed
@@ -33,40 +37,194 @@ class CustomLLM(BaseLLM):
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> Union[str, Any]:
+        """Call the LLM with the given messages.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            Either a text response from the LLM or the result of a tool function call.
+
+        Raises:
+            TimeoutError: If the LLM request times out.
+            RuntimeError: If the LLM request fails for other reasons.
+            ValueError: If the response format is invalid.
+        """
         # Implement your own logic to call the LLM
         # For example, using requests:
         import requests

-        headers = {
-            "Authorization": f"Bearer {self.api_key}",
-            "Content-Type": "application/json"
-        }
-
-        # Convert string message to proper format if needed
-        if isinstance(messages, str):
-            messages = [{"role": "user", "content": messages}]
-
-        data = {
-            "messages": messages,
-            "tools": tools
-        }
-
-        response = requests.post(self.endpoint, headers=headers, json=data)
-        return response.json()["choices"][0]["message"]["content"]
+        try:
+            headers = {
+                "Authorization": f"Bearer {self.api_key}",
+                "Content-Type": "application/json"
+            }
+
+            # Convert string message to proper format if needed
+            if isinstance(messages, str):
+                messages = [{"role": "user", "content": messages}]
+
+            data = {
+                "messages": messages,
+                "tools": tools
+            }
+
+            response = requests.post(
+                self.endpoint,
+                headers=headers,
+                json=data,
+                timeout=30  # Set a reasonable timeout
+            )
+            response.raise_for_status()  # Raise an exception for HTTP errors
+            return response.json()["choices"][0]["message"]["content"]
+        except requests.Timeout:
+            raise TimeoutError("LLM request timed out")
+        except requests.RequestException as e:
+            raise RuntimeError(f"LLM request failed: {str(e)}")
+        except (KeyError, IndexError, ValueError) as e:
+            raise ValueError(f"Invalid response format: {str(e)}")

     def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
         # Return True if your LLM supports function calling
         return True

     def supports_stop_words(self) -> bool:
+        """Check if the LLM supports stop words.
+
+        Returns:
+            True if the LLM supports stop words, False otherwise.
+        """
         # Return True if your LLM supports stop words
         return True

     def get_context_window_size(self) -> int:
+        """Get the context window size of the LLM.
+
+        Returns:
+            The context window size as an integer.
+        """
         # Return the context window size of your LLM
         return 8192
 ```

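+Before wiring this class into agents, you can sanity-check it with a direct call. This is a minimal illustrative sketch; the endpoint URL and key below are placeholders, not real values:
+
+```python
+# Hypothetical values for illustration only
+llm = CustomLLM(
+    api_key="your-api-key",
+    endpoint="https://api.example.com/v1/chat/completions",
+)
+
+# A plain string is converted to a single user message inside call()
+print(llm.call("Say hello in one sentence."))
+```
+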
+## Error Handling Best Practices
+
+When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices:
+
+### 1. Implement Try-Except Blocks for API Calls
+
+Always wrap API calls in try-except blocks to handle different types of errors:
+
+```python
+def call(
+    self,
+    messages: Union[str, List[Dict[str, str]]],
+    tools: Optional[List[dict]] = None,
+    callbacks: Optional[List[Any]] = None,
+    available_functions: Optional[Dict[str, Any]] = None,
+) -> Union[str, Any]:
+    try:
+        # API call implementation
+        response = requests.post(
+            self.endpoint,
+            headers=self.headers,
+            json=self.prepare_payload(messages),
+            timeout=30  # Set a reasonable timeout
+        )
+        response.raise_for_status()  # Raise an exception for HTTP errors
+        return response.json()["choices"][0]["message"]["content"]
+    except requests.Timeout:
+        raise TimeoutError("LLM request timed out")
+    except requests.RequestException as e:
+        raise RuntimeError(f"LLM request failed: {str(e)}")
+    except (KeyError, IndexError, ValueError) as e:
+        raise ValueError(f"Invalid response format: {str(e)}")
+```
+
+### 2. Implement Retry Logic for Transient Failures
+
+For transient failures like network issues or rate limiting, implement retry logic with exponential backoff:
+
+```python
+def call(
+    self,
+    messages: Union[str, List[Dict[str, str]]],
+    tools: Optional[List[dict]] = None,
+    callbacks: Optional[List[Any]] = None,
+    available_functions: Optional[Dict[str, Any]] = None,
+) -> Union[str, Any]:
+    import time
+
+    max_retries = 3
+    retry_delay = 1  # seconds
+
+    for attempt in range(max_retries):
+        try:
+            response = requests.post(
+                self.endpoint,
+                headers=self.headers,
+                json=self.prepare_payload(messages),
+                timeout=30
+            )
+            response.raise_for_status()
+            return response.json()["choices"][0]["message"]["content"]
+        except (requests.Timeout, requests.ConnectionError) as e:
+            if attempt < max_retries - 1:
+                time.sleep(retry_delay * (2 ** attempt))  # Exponential backoff
+                continue
+            raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}")
+        except requests.RequestException as e:
+            raise RuntimeError(f"LLM request failed: {str(e)}")
+```
+
+### 3. Validate Input Parameters
+
+Always validate input parameters to prevent runtime errors:
+
+```python
+def __init__(self, api_key: str, endpoint: str):
+    super().__init__()
+    if not api_key or not isinstance(api_key, str):
+        raise ValueError("Invalid API key: must be a non-empty string")
+    if not endpoint or not isinstance(endpoint, str):
+        raise ValueError("Invalid endpoint URL: must be a non-empty string")
+    self.api_key = api_key
+    self.endpoint = endpoint
+```
+
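+If you want a stricter check than "non-empty string", you can also verify that the endpoint parses as an HTTP(S) URL. This is an optional sketch using only the standard library; the helper name is illustrative and not part of `BaseLLM`:
+
+```python
+from urllib.parse import urlparse
+
+def _validate_endpoint(endpoint: str) -> str:
+    """Return the endpoint unchanged if it looks like a usable HTTP(S) URL."""
+    parsed = urlparse(endpoint)
+    if parsed.scheme not in ("http", "https") or not parsed.netloc:
+        raise ValueError(f"Invalid endpoint URL: {endpoint!r}")
+    return endpoint
+```
+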
+### 4. Handle Authentication Errors Gracefully
+
+Provide clear error messages for authentication failures:
+
+```python
+def call(
+    self,
+    messages: Union[str, List[Dict[str, str]]],
+    tools: Optional[List[dict]] = None,
+    callbacks: Optional[List[Any]] = None,
+    available_functions: Optional[Dict[str, Any]] = None,
+) -> Union[str, Any]:
+    try:
+        # self.headers and data are prepared as in the earlier examples
+        response = requests.post(self.endpoint, headers=self.headers, json=data)
+        if response.status_code == 401:
+            raise ValueError("Authentication failed: Invalid API key or token")
+        elif response.status_code == 403:
+            raise ValueError("Authorization failed: Insufficient permissions")
+        response.raise_for_status()
+        # Process response
+    except Exception:
+        # Log or wrap the error as needed, then re-raise
+        raise
+```
+
 ## Example: JWT-based Authentication

 For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this:

@@ -78,6 +236,10 @@ from typing import Any, Dict, List, Optional, Union
 class JWTAuthLLM(BaseLLM):
     def __init__(self, jwt_token: str, endpoint: str):
         super().__init__()  # Initialize the base class to set default attributes
+        if not jwt_token or not isinstance(jwt_token, str):
+            raise ValueError("Invalid JWT token: must be a non-empty string")
+        if not endpoint or not isinstance(endpoint, str):
+            raise ValueError("Invalid endpoint URL: must be a non-empty string")
         self.jwt_token = jwt_token
         self.endpoint = endpoint
         self.stop = []  # You can customize stop words if needed
@@ -89,9 +251,282 @@ class JWTAuthLLM(BaseLLM):
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> Union[str, Any]:
+        """Call the LLM with JWT authentication.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            Either a text response from the LLM or the result of a tool function call.
+
+        Raises:
+            TimeoutError: If the LLM request times out.
+            RuntimeError: If the LLM request fails for other reasons.
+            ValueError: If the response format is invalid.
+        """
         # Implement your own logic to call the LLM with JWT authentication
         import requests

+        try:
+            headers = {
+                "Authorization": f"Bearer {self.jwt_token}",
+                "Content-Type": "application/json"
+            }
+
+            # Convert string message to proper format if needed
+            if isinstance(messages, str):
+                messages = [{"role": "user", "content": messages}]
+
+            data = {
+                "messages": messages,
+                "tools": tools
+            }
+
+            response = requests.post(
+                self.endpoint,
+                headers=headers,
+                json=data,
+                timeout=30  # Set a reasonable timeout
+            )
+
+            if response.status_code == 401:
+                raise ValueError("Authentication failed: Invalid JWT token")
+            elif response.status_code == 403:
+                raise ValueError("Authorization failed: Insufficient permissions")
+
+            response.raise_for_status()  # Raise an exception for HTTP errors
+            return response.json()["choices"][0]["message"]["content"]
+        except requests.Timeout:
+            raise TimeoutError("LLM request timed out")
+        except requests.RequestException as e:
+            raise RuntimeError(f"LLM request failed: {str(e)}")
+        except (KeyError, IndexError, ValueError) as e:
+            raise ValueError(f"Invalid response format: {str(e)}")
+
+    def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            True if the LLM supports function calling, False otherwise.
+        """
+        return True
+
+    def supports_stop_words(self) -> bool:
+        """Check if the LLM supports stop words.
+
+        Returns:
+            True if the LLM supports stop words, False otherwise.
+        """
+        return True
+
+    def get_context_window_size(self) -> int:
+        """Get the context window size of the LLM.
+
+        Returns:
+            The context window size as an integer.
+        """
+        return 8192
+```
+
+## Troubleshooting
+
+Here are some common issues you might encounter when implementing custom LLMs and how to resolve them:
+
+### 1. Authentication Failures
+
+**Symptoms**: 401 Unauthorized or 403 Forbidden errors
+
+**Solutions**:
+- Verify that your API key or JWT token is valid and not expired
+- Check that you're using the correct authentication header format
+- Ensure that your token has the necessary permissions
+
+### 2. Timeout Issues
+
+**Symptoms**: Requests taking too long or timing out
+
+**Solutions**:
+- Implement timeout handling as shown in the examples
+- Use retry logic with exponential backoff
+- Check for slow networks or proxies, and raise the timeout for long-running generations if needed
+
+### 3. Response Parsing Errors
+
+**Symptoms**: KeyError, IndexError, or ValueError when processing responses
+
+**Solutions**:
+- Validate the response format before accessing nested fields (a small helper sketch follows below)
+- Implement proper error handling for malformed responses
+- Check the API documentation for the expected response format
+
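+One way to apply the first suggestion is a small helper that fails with a clear message when the payload shape is unexpected. This is an illustrative sketch (the helper is not part of `BaseLLM`) for an OpenAI-style response body:
+
+```python
+def _extract_content(response_data: dict) -> str:
+    """Return the first message's content, or raise a descriptive ValueError."""
+    try:
+        return response_data["choices"][0]["message"]["content"]
+    except (KeyError, IndexError, TypeError) as e:
+        raise ValueError(f"Unexpected response format: {response_data!r}") from e
+```
+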
+### 4. Rate Limiting
+
+**Symptoms**: 429 Too Many Requests errors
+
+**Solutions**:
+- Implement rate limiting in your custom LLM
+- Add exponential backoff for retries
+- Consider using a token bucket algorithm for more precise rate control (a minimal sketch follows below)
+
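+Here is one possible shape for that token bucket idea. It is a sketch, not a CrewAI API: the class name and the `acquire()` call pattern are illustrative, and you would call `acquire()` immediately before each HTTP request in `call()`:
+
+```python
+import time
+
+class TokenBucket:
+    """Allow short bursts while capping the average request rate."""
+
+    def __init__(self, rate_per_sec: float, capacity: int):
+        self.rate = rate_per_sec
+        self.capacity = capacity
+        self.tokens = float(capacity)
+        self.last_refill = time.monotonic()
+
+    def acquire(self) -> None:
+        """Block until a token is available, then consume it."""
+        while True:
+            now = time.monotonic()
+            self.tokens = min(self.capacity, self.tokens + (now - self.last_refill) * self.rate)
+            self.last_refill = now
+            if self.tokens >= 1:
+                self.tokens -= 1
+                return
+            time.sleep((1 - self.tokens) / self.rate)
+```
+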
+## Advanced Features
+
+### Logging
+
+Adding logging to your custom LLM can help with debugging and monitoring:
+
+```python
+import logging
+from typing import Any, Dict, List, Optional, Union
+
+class LoggingLLM(BaseLLM):
+    def __init__(self, api_key: str, endpoint: str):
+        super().__init__()
+        self.api_key = api_key
+        self.endpoint = endpoint
+        self.logger = logging.getLogger("crewai.llm.custom")
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages")
+        try:
+            # API call implementation
+            response = self._make_api_call(messages, tools)
+            self.logger.debug(f"LLM response received: {response[:100]}...")
+            return response
+        except Exception as e:
+            self.logger.error(f"LLM call failed: {str(e)}")
+            raise
+```
+
+### Rate Limiting
+
+Implementing rate limiting can help avoid overwhelming the LLM API:
+
+```python
+import time
+from typing import Any, Dict, List, Optional, Union
+
+class RateLimitedLLM(BaseLLM):
+    def __init__(
+        self,
+        api_key: str,
+        endpoint: str,
+        requests_per_minute: int = 60
+    ):
+        super().__init__()
+        self.api_key = api_key
+        self.endpoint = endpoint
+        self.requests_per_minute = requests_per_minute
+        self.request_times: List[float] = []
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        self._enforce_rate_limit()
+        # Record this request time
+        self.request_times.append(time.time())
+        # Make the actual API call
+        return self._make_api_call(messages, tools)
+
+    def _enforce_rate_limit(self) -> None:
+        """Enforce the rate limit by waiting if necessary."""
+        now = time.time()
+        # Remove request times older than 1 minute
+        self.request_times = [t for t in self.request_times if now - t < 60]
+
+        if len(self.request_times) >= self.requests_per_minute:
+            # Calculate how long to wait
+            oldest_request = min(self.request_times)
+            wait_time = 60 - (now - oldest_request)
+            if wait_time > 0:
+                time.sleep(wait_time)
+```
+
+### Metrics Collection
+
+Collecting metrics can help you monitor your LLM usage:
+
+```python
+import time
+from typing import Any, Dict, List, Optional, Union
+
+class MetricsCollectingLLM(BaseLLM):
+    def __init__(self, api_key: str, endpoint: str):
+        super().__init__()
+        self.api_key = api_key
+        self.endpoint = endpoint
+        self.metrics: Dict[str, Any] = {
+            "total_calls": 0,
+            "total_tokens": 0,
+            "errors": 0,
+            "latency": []
+        }
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        start_time = time.time()
+        self.metrics["total_calls"] += 1
+
+        try:
+            response = self._make_api_call(messages, tools)
+            # Estimate tokens (simplified)
+            if isinstance(messages, str):
+                token_estimate = len(messages) // 4
+            else:
+                token_estimate = sum(len(m.get("content", "")) // 4 for m in messages)
+            self.metrics["total_tokens"] += token_estimate
+            return response
+        except Exception as e:
+            self.metrics["errors"] += 1
+            raise
+        finally:
+            latency = time.time() - start_time
+            self.metrics["latency"].append(latency)
+
+    def get_metrics(self) -> Dict[str, Any]:
+        """Return the collected metrics."""
+        avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0
+        return {
+            **self.metrics,
+            "avg_latency": avg_latency
+        }
+```
+
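+Assuming you have filled in the `_make_api_call` helper that the class above relies on, reading the metrics back after a few calls might look like this (the values and endpoint are purely illustrative):
+
+```python
+llm = MetricsCollectingLLM(api_key="your-api-key", endpoint="https://api.example.com/v1/chat/completions")
+# ... run some agent tasks with this LLM ...
+stats = llm.get_metrics()
+print(stats["total_calls"], stats["errors"], f"{stats['avg_latency']:.2f}s avg latency")
+```
+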
+## Advanced Usage: Function Calling
+
+If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
+
+```python
+import json
+from typing import Any, Dict, List, Optional, Union
+
+def call(
+    self,
+    messages: Union[str, List[Dict[str, str]]],
+    tools: Optional[List[dict]] = None,
+    callbacks: Optional[List[Any]] = None,
+    available_functions: Optional[Dict[str, Any]] = None,
+) -> Union[str, Any]:
+    import requests
+
+    try:
         headers = {
             "Authorization": f"Bearer {self.jwt_token}",
             "Content-Type": "application/json"
@@ -106,20 +541,47 @@ class JWTAuthLLM(BaseLLM):
             "tools": tools
         }

-        response = requests.post(self.endpoint, headers=headers, json=data)
-        return response.json()["choices"][0]["message"]["content"]
+        response = requests.post(
+            self.endpoint,
+            headers=headers,
+            json=data,
+            timeout=30
+        )
+        response.raise_for_status()
+        response_data = response.json()

-    def supports_function_calling(self) -> bool:
-        # Return True if your LLM supports function calling
-        return True
+        # Check if the LLM wants to call a function
+        if response_data["choices"][0]["message"].get("tool_calls"):
+            tool_calls = response_data["choices"][0]["message"]["tool_calls"]
+
+            # Process each tool call
+            for tool_call in tool_calls:
+                function_name = tool_call["function"]["name"]
+                function_args = json.loads(tool_call["function"]["arguments"])
+
+                if available_functions and function_name in available_functions:
+                    function_to_call = available_functions[function_name]
+                    function_response = function_to_call(**function_args)
+
+                    # Add the function response to the messages
+                    # Note: some chat-completion APIs also expect the assistant message
+                    # that contained the tool_calls to be appended before these tool results
+                    messages.append({
+                        "role": "tool",
+                        "tool_call_id": tool_call["id"],
+                        "name": function_name,
+                        "content": str(function_response)
+                    })
+
+            # Call the LLM again with the updated messages
+            return self.call(messages, tools, callbacks, available_functions)

-    def supports_stop_words(self) -> bool:
-        # Return True if your LLM supports stop words
-        return True
-
-    def get_context_window_size(self) -> int:
-        # Return the context window size of your LLM
-        return 8192
+        # Return the text response if no function call
+        return response_data["choices"][0]["message"]["content"]
+    except requests.Timeout:
+        raise TimeoutError("LLM request timed out")
+    except requests.RequestException as e:
+        raise RuntimeError(f"LLM request failed: {str(e)}")
+    except (KeyError, IndexError, ValueError) as e:
+        raise ValueError(f"Invalid response format: {str(e)}")
 ```

 ## Using Your Custom LLM with CrewAI

@@ -127,6 +589,9 @@ class JWTAuthLLM(BaseLLM):
 Once you've implemented your custom LLM, you can use it with CrewAI agents and crews:

 ```python
+from crewai import Agent, Task, Crew
+from typing import Dict, Any
+
 # Create your custom LLM instance
 jwt_llm = JWTAuthLLM(
     jwt_token="your.jwt.token",
     endpoint="https://example.com/llm"
@@ -151,65 +616,17 @@ task = Task(
 # Execute the task
 result = agent.execute_task(task)
 print(result)
-```

-## Advanced Usage: Function Calling
+# Or use it with a crew
+crew = Crew(
+    agents=[agent],
+    tasks=[task],
+    manager_llm=jwt_llm,  # Use your custom LLM for the manager
+)

-If your LLM supports function calling, you can implement the function calling logic in your custom LLM:
-
-```python
-def call(
-    self,
-    messages: Union[str, List[Dict[str, str]]],
-    tools: Optional[List[dict]] = None,
-    callbacks: Optional[List[Any]] = None,
-    available_functions: Optional[Dict[str, Any]] = None,
-) -> Union[str, Any]:
-    import requests
-
-    headers = {
-        "Authorization": f"Bearer {self.jwt_token}",
-        "Content-Type": "application/json"
-    }
-
-    # Convert string message to proper format if needed
-    if isinstance(messages, str):
-        messages = [{"role": "user", "content": messages}]
-
-    data = {
-        "messages": messages,
-        "tools": tools
-    }
-
-    response = requests.post(self.endpoint, headers=headers, json=data)
-    response_data = response.json()
-
-    # Check if the LLM wants to call a function
-    if response_data["choices"][0]["message"].get("tool_calls"):
-        tool_calls = response_data["choices"][0]["message"]["tool_calls"]
-
-        # Process each tool call
-        for tool_call in tool_calls:
-            function_name = tool_call["function"]["name"]
-            function_args = json.loads(tool_call["function"]["arguments"])
-
-            if available_functions and function_name in available_functions:
-                function_to_call = available_functions[function_name]
-                function_response = function_to_call(**function_args)
-
-                # Add the function response to the messages
-                messages.append({
-                    "role": "tool",
-                    "tool_call_id": tool_call["id"],
-                    "name": function_name,
-                    "content": str(function_response)
-                })
-
-        # Call the LLM again with the updated messages
-        return self.call(messages, tools, callbacks, available_functions)
-
-    # Return the text response if no function call
-    return response_data["choices"][0]["message"]["content"]
+# Run the crew
+result = crew.kickoff()
+print(result)
 ```

 ## Implementing Your Own Authentication Mechanism
diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py
index ca5b9ccb7..0d6b06961 100644
--- a/src/crewai/__init__.py
+++ b/src/crewai/__init__.py
@@ -4,7 +4,7 @@ from crewai.agent import Agent
 from crewai.crew import Crew
 from crewai.flow.flow import Flow
 from crewai.knowledge.knowledge import Knowledge
-from crewai.llm import BaseLLM, LLM
+from crewai.llm import LLM, BaseLLM
 from crewai.process import Process
 from crewai.task import Task
diff --git a/src/crewai/crew.py b/src/crewai/crew.py
index 8cfba074a..ed3cfe4b2 100644
--- a/src/crewai/crew.py
+++ b/src/crewai/crew.py
@@ -6,10 +6,9 @@ import warnings
 from concurrent.futures import Future
 from copy import copy as shallow_copy
 from hashlib import md5
-from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union, cast, TypeVar
+from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast

 from langchain_core.tools import BaseTool as LangchainBaseTool
-from crewai.tools.base_tool import BaseTool, Tool
 from pydantic import (
     UUID4,
     BaseModel,
@@ -38,7 +37,7 @@ from crewai.task import Task
 from crewai.tasks.conditional_task import ConditionalTask
 from crewai.tasks.task_output import TaskOutput
 from crewai.tools.agent_tools.agent_tools import AgentTools
-from crewai.tools.base_tool import Tool
+from crewai.tools.base_tool import BaseTool, Tool
 from crewai.types.usage_metrics import UsageMetrics
 from crewai.utilities import I18N, FileHandler, Logger, RPMController
 from crewai.utilities.constants import TRAINING_DATA_FILE
diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py
index b833a57a0..c3b0de1c0 100644
--- a/tests/custom_llm_test.py
+++ b/tests/custom_llm_test.py
@@ -7,9 +7,19 @@ from crewai.utilities.llm_utils import create_llm


 class CustomLLM(BaseLLM):
-    """Custom LLM implementation for testing."""
+    """Custom LLM implementation for testing.
+
+    This is a simple implementation of the BaseLLM abstract base class
+    that returns a predefined response for testing purposes.
+    """

     def __init__(self, response: str = "Custom LLM response"):
+        """Initialize the CustomLLM with a predefined response.
+
+        Args:
+            response: The predefined response to return from call().
+        """
+        super().__init__()
         self.response = response
         self.calls = []
         self.stop = []
@@ -21,7 +31,17 @@
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> Union[str, Any]:
-        """Record the call and return the predefined response."""
+        """Record the call and return the predefined response.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            The predefined response string.
+        """
         self.calls.append({
             "messages": messages,
             "tools": tools,
@@ -31,15 +51,27 @@
         return self.response

     def supports_function_calling(self) -> bool:
-        """Return True to indicate that function calling is supported."""
+        """Return True to indicate that function calling is supported.
+
+        Returns:
+            True, indicating that this LLM supports function calling.
+        """
         return True

     def supports_stop_words(self) -> bool:
-        """Return True to indicate that stop words are supported."""
+        """Return True to indicate that stop words are supported.
+
+        Returns:
+            True, indicating that this LLM supports stop words.
+        """
         return True

     def get_context_window_size(self) -> int:
-        """Return a default context window size."""
+        """Return a default context window size.
+
+        Returns:
+            8192, a typical context window size for modern LLMs.
+        """
         return 8192
@@ -119,3 +151,147 @@ def test_custom_llm_with_jwt_auth():
     assert len(jwt_llm.calls) > 0
     # Verify that the response from the JWT-authenticated LLM was used
     assert response == "Response from JWT-authenticated LLM"
+
+
+def test_jwt_auth_llm_validation():
+    """Test that JWT token validation works correctly."""
+    # Test with invalid JWT token (empty string)
+    with pytest.raises(ValueError, match="Invalid JWT token"):
+        JWTAuthLLM(jwt_token="")
+
+    # Test with invalid JWT token (non-string)
+    with pytest.raises(ValueError, match="Invalid JWT token"):
+        JWTAuthLLM(jwt_token=None)
+
+
+class TimeoutHandlingLLM(BaseLLM):
+    """Custom LLM implementation with timeout handling and retry logic."""
+
+    def __init__(self, max_retries: int = 3, timeout: int = 30):
+        """Initialize the TimeoutHandlingLLM with retry and timeout settings.
+
+        Args:
+            max_retries: Maximum number of retry attempts.
+            timeout: Timeout in seconds for each API call.
+        """
+        super().__init__()
+        self.max_retries = max_retries
+        self.timeout = timeout
+        self.calls = []
+        self.stop = []
+        self.fail_count = 0  # Number of times to simulate failure
+
+    def call(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> Union[str, Any]:
+        """Simulate API calls with timeout handling and retry logic.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            A response string based on whether this is the first attempt or a retry.
+
+        Raises:
+            TimeoutError: If all retry attempts fail.
+        """
+        # Record the initial call
+        self.calls.append({
+            "messages": messages,
+            "tools": tools,
+            "callbacks": callbacks,
+            "available_functions": available_functions,
+            "attempt": 0
+        })
+
+        # Simulate retry logic
+        for attempt in range(self.max_retries):
+            # Skip the first attempt recording since we already did that above
+            if attempt == 0:
+                # Simulate a failure if fail_count > 0
+                if self.fail_count > 0:
+                    self.fail_count -= 1
+                    # If we've used all retries, raise an error
+                    if attempt == self.max_retries - 1:
+                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
+                    # Otherwise, continue to the next attempt (simulating backoff)
+                    continue
+                else:
+                    # Success on first attempt
+                    return "First attempt response"
+            else:
+                # This is a retry attempt (attempt > 0)
+                # Always record retry attempts
+                self.calls.append({
+                    "retry_attempt": attempt,
+                    "messages": messages,
+                    "tools": tools,
+                    "callbacks": callbacks,
+                    "available_functions": available_functions
+                })
+
+                # Simulate a failure if fail_count > 0
+                if self.fail_count > 0:
+                    self.fail_count -= 1
+                    # If we've used all retries, raise an error
+                    if attempt == self.max_retries - 1:
+                        raise TimeoutError(f"LLM request failed after {self.max_retries} attempts")
+                    # Otherwise, continue to the next attempt (simulating backoff)
+                    continue
+                else:
+                    # Success on retry
+                    return "Response after retry"
+
+    def supports_function_calling(self) -> bool:
+        """Return True to indicate that function calling is supported.
+
+        Returns:
+            True, indicating that this LLM supports function calling.
+        """
+        return True
+
+    def supports_stop_words(self) -> bool:
+        """Return True to indicate that stop words are supported.
+
+        Returns:
+            True, indicating that this LLM supports stop words.
+        """
+        return True
+
+    def get_context_window_size(self) -> int:
+        """Return a default context window size.
+
+        Returns:
+            8192, a typical context window size for modern LLMs.
+        """
+        return 8192
+
+
+def test_timeout_handling_llm():
+    """Test a custom LLM implementation with timeout handling and retry logic."""
+    # Test successful first attempt
+    llm = TimeoutHandlingLLM()
+    response = llm.call("Test message")
+    assert response == "First attempt response"
+    assert len(llm.calls) == 1
+
+    # Test successful retry
+    llm = TimeoutHandlingLLM()
+    llm.fail_count = 1  # Fail once, then succeed
+    response = llm.call("Test message")
+    assert response == "Response after retry"
+    assert len(llm.calls) == 2  # Initial call + successful retry call
+
+    # Test failure after all retries
+    llm = TimeoutHandlingLLM(max_retries=2)
+    llm.fail_count = 2  # Fail twice, which is all retries
+    with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
+        llm.call("Test message")
+    assert len(llm.calls) == 2  # Initial call + failed retry attempt