from collections import deque
from typing import Any, Dict, List, Optional, Union
import logging
import time

import jwt
import pytest

from crewai.llm import LLM
from crewai.utilities.llm_utils import create_llm


class CustomLLM(LLM):
    """Custom LLM implementation for testing.

    This is a simple implementation of the LLM abstract base class that
    returns a predefined response for testing purposes.
    """

    def __init__(self, response: str = "Custom LLM response"):
        """Initialize the CustomLLM with a predefined response.

        Args:
            response: The predefined response to return from call().
        """
        super().__init__()
        self.response = response
        self.calls = []
        self.stop = []

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Record the call and return the predefined response.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            The predefined response string.
        """
        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions,
        })
        return self.response

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported.

        Returns:
            True, indicating that this LLM supports function calling.
        """
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported.

        Returns:
            True, indicating that this LLM supports stop words.
        """
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size.

        Returns:
            8192, a typical context window size for modern LLMs.
        """
        return 8192


def test_custom_llm_implementation():
    """Test that a custom LLM implementation works with create_llm."""
    custom_llm = CustomLLM(response="The answer is 42")

    # Test that create_llm returns the custom LLM instance directly
    result_llm = create_llm(custom_llm)
    assert result_llm is custom_llm

    # Test calling the custom LLM
    response = result_llm.call("What is the answer to life, the universe, and everything?")

    # Verify that the custom LLM was called
    assert len(custom_llm.calls) > 0

    # Verify that the response from the custom LLM was used
    assert response == "The answer is 42"
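
# The following test is a minimal companion sketch, not part of the original
# suite: because CustomLLM records every argument passed to call(), a test
# can assert on the tool schemas that were forwarded. The tool dict below is
# a hypothetical example payload, not a schema defined elsewhere in crewai.
def test_custom_llm_records_tool_schemas():
    """Sketch: tool schemas passed to call() are captured in `calls`."""
    custom_llm = CustomLLM(response="ok")
    example_tool = {
        "type": "function",
        "function": {"name": "lookup", "parameters": {"type": "object"}},
    }
    custom_llm.call("Use the tool", tools=[example_tool])
    assert custom_llm.calls[0]["tools"] == [example_tool]
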
class JWTAuthLLM(LLM):
    """Custom LLM implementation with JWT authentication.

    This class demonstrates how to implement a custom LLM that uses JWT
    authentication instead of API key-based authentication. It validates
    the JWT token before each call and checks for token expiration.
    """

    def __init__(self, jwt_token: str, expiration_buffer: int = 60):
        """Initialize the JWTAuthLLM with a JWT token.

        Args:
            jwt_token: The JWT token to use for authentication.
            expiration_buffer: Buffer time in seconds to warn about token
                expiration. Default is 60 seconds.

        Raises:
            ValueError: If the JWT token is invalid or missing.
        """
        super().__init__()
        if not jwt_token or not isinstance(jwt_token, str):
            raise ValueError("Invalid JWT token")
        self.jwt_token = jwt_token
        self.expiration_buffer = expiration_buffer
        self.calls = []
        self.stop = []
        # Validate the token immediately
        self._validate_token()

    def _validate_token(self) -> None:
        """Validate the JWT token.

        Checks if the token is valid and not expired. Also warns if the token
        is about to expire within the expiration_buffer time.

        Raises:
            ValueError: If the token is invalid, expired, or malformed.
        """
        try:
            # Decode without verification to check expiration.
            # In a real implementation, you would verify the signature.
            decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})

            # Check if the token is expired or about to expire
            if "exp" in decoded:
                expiration_time = decoded["exp"]
                current_time = time.time()

                if expiration_time < current_time:
                    raise ValueError("JWT token has expired")

                if expiration_time < current_time + self.expiration_buffer:
                    # Token will expire soon, log a warning
                    logging.warning(
                        f"JWT token will expire in {expiration_time - current_time} seconds"
                    )
        except jwt.PyJWTError as e:
            raise ValueError(f"Invalid JWT token format: {str(e)}")

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        """Call the LLM with JWT authentication.

        Validates the JWT token before making the call to ensure it's
        still valid.

        Args:
            messages: Input messages for the LLM.
            tools: Optional list of tool schemas for function calling.
            callbacks: Optional list of callback functions.
            available_functions: Optional dict mapping function names to callables.

        Returns:
            The LLM response.

        Raises:
            ValueError: If the JWT token is invalid or expired.
        """
        # Validate the token before making the call
        self._validate_token()

        self.calls.append({
            "messages": messages,
            "tools": tools,
            "callbacks": callbacks,
            "available_functions": available_functions,
        })
        # In a real implementation, this would use the JWT token to
        # authenticate with an external service
        return "Response from JWT-authenticated LLM"

    def supports_function_calling(self) -> bool:
        """Return True to indicate that function calling is supported."""
        return True

    def supports_stop_words(self) -> bool:
        """Return True to indicate that stop words are supported."""
        return True

    def get_context_window_size(self) -> int:
        """Return a default context window size."""
        return 8192


def test_custom_llm_with_jwt_auth():
    """Test a custom LLM implementation with JWT authentication."""
    # Create a valid JWT token that expires 1 hour from now
    valid_token = jwt.encode(
        {"exp": int(time.time()) + 3600}, "secret", algorithm="HS256"
    )
    jwt_llm = JWTAuthLLM(jwt_token=valid_token)

    # Test that create_llm returns the JWT-authenticated LLM instance directly
    result_llm = create_llm(jwt_llm)
    assert result_llm is jwt_llm

    # Test calling the JWT-authenticated LLM
    response = result_llm.call("Test message")

    # Verify that the JWT-authenticated LLM was called
    assert len(jwt_llm.calls) > 0

    # Verify that the response from the JWT-authenticated LLM was used
    assert response == "Response from JWT-authenticated LLM"
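
# A sketch of a test for the near-expiry warning path of _validate_token,
# assuming pytest's built-in caplog fixture: a token expiring inside the
# default 60-second buffer should be accepted but emit a logging warning.
def test_jwt_auth_llm_warns_when_token_near_expiry(caplog):
    """Sketch: a token expiring within the buffer logs a warning but is accepted."""
    near_expiry_token = jwt.encode(
        {"exp": int(time.time()) + 30},  # inside the default 60-second buffer
        "secret",
        algorithm="HS256",
    )
    with caplog.at_level(logging.WARNING):
        jwt_llm = JWTAuthLLM(jwt_token=near_expiry_token)
    assert jwt_llm.jwt_token == near_expiry_token
    assert "will expire" in caplog.text
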
format"): JWTAuthLLM(jwt_token="not.a.valid.jwt.token") # Test with valid token # Create a token that expires 1 hour from now valid_token = jwt.encode( {"exp": int(time.time()) + 3600}, "secret", algorithm="HS256" ) # This should not raise an exception jwt_llm = JWTAuthLLM(jwt_token=valid_token) assert jwt_llm.jwt_token == valid_token class TimeoutHandlingLLM(LLM): """Custom LLM implementation with timeout handling and retry logic.""" def __init__(self, max_retries: int = 3, timeout: int = 30): """Initialize the TimeoutHandlingLLM with retry and timeout settings. Args: max_retries: Maximum number of retry attempts. timeout: Timeout in seconds for each API call. """ super().__init__() self.max_retries = max_retries self.timeout = timeout self.calls = [] self.stop = [] self.fail_count = 0 # Number of times to simulate failure def call( self, messages: Union[str, List[Dict[str, str]]], tools: Optional[List[dict]] = None, callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, ) -> Union[str, Any]: """Simulate API calls with timeout handling and retry logic. Args: messages: Input messages for the LLM. tools: Optional list of tool schemas for function calling. callbacks: Optional list of callback functions. available_functions: Optional dict mapping function names to callables. Returns: A response string based on whether this is the first attempt or a retry. Raises: TimeoutError: If all retry attempts fail. """ # Record the initial call self.calls.append({ "messages": messages, "tools": tools, "callbacks": callbacks, "available_functions": available_functions, "attempt": 0 }) # Simulate retry logic for attempt in range(self.max_retries): # Skip the first attempt recording since we already did that above if attempt == 0: # Simulate a failure if fail_count > 0 if self.fail_count > 0: self.fail_count -= 1 # If we've used all retries, raise an error if attempt == self.max_retries - 1: raise TimeoutError(f"LLM request failed after {self.max_retries} attempts") # Otherwise, continue to the next attempt (simulating backoff) continue else: # Success on first attempt return "First attempt response" else: # This is a retry attempt (attempt > 0) # Always record retry attempts self.calls.append({ "retry_attempt": attempt, "messages": messages, "tools": tools, "callbacks": callbacks, "available_functions": available_functions }) # Simulate a failure if fail_count > 0 if self.fail_count > 0: self.fail_count -= 1 # If we've used all retries, raise an error if attempt == self.max_retries - 1: raise TimeoutError(f"LLM request failed after {self.max_retries} attempts") # Otherwise, continue to the next attempt (simulating backoff) continue else: # Success on retry return "Response after retry" def supports_function_calling(self) -> bool: """Return True to indicate that function calling is supported. Returns: True, indicating that this LLM supports function calling. """ return True def supports_stop_words(self) -> bool: """Return True to indicate that stop words are supported. Returns: True, indicating that this LLM supports stop words. """ return True def get_context_window_size(self) -> int: """Return a default context window size. Returns: 8192, a typical context window size for modern LLMs. 
""" return 8192 def test_timeout_handling_llm(): """Test a custom LLM implementation with timeout handling and retry logic.""" # Test successful first attempt llm = TimeoutHandlingLLM() response = llm.call("Test message") assert response == "First attempt response" assert len(llm.calls) == 1 # Test successful retry llm = TimeoutHandlingLLM() llm.fail_count = 1 # Fail once, then succeed response = llm.call("Test message") assert response == "Response after retry" assert len(llm.calls) == 2 # Initial call + successful retry call # Test failure after all retries llm = TimeoutHandlingLLM(max_retries=2) llm.fail_count = 2 # Fail twice, which is all retries with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"): llm.call("Test message") assert len(llm.calls) == 2 # Initial call + failed retry attempt def test_rate_limited_llm(): """Test that rate limiting works correctly.""" # Create a rate limited LLM with a very low limit (2 requests per minute) llm = RateLimitedLLM(requests_per_minute=2) # First request should succeed response1 = llm.call("Test message 1") assert response1 == "Rate limited response" assert len(llm.calls) == 1 # Second request should succeed response2 = llm.call("Test message 2") assert response2 == "Rate limited response" assert len(llm.calls) == 2 # Third request should fail due to rate limiting with pytest.raises(ValueError, match="Rate limit exceeded"): llm.call("Test message 3") # Test with invalid requests_per_minute with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"): RateLimitedLLM(requests_per_minute=0) with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"): RateLimitedLLM(requests_per_minute=-1) def test_rate_limit_reset(): """Test that rate limits reset after the time window passes.""" # Create a rate limited LLM with a very low limit (1 request per minute) # and a short time window for testing (1 second instead of 60 seconds) time_window = 1 # 1 second instead of 60 seconds llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window) # First request should succeed response1 = llm.call("Test message 1") assert response1 == "Rate limited response" # Second request should fail due to rate limiting with pytest.raises(ValueError, match="Rate limit exceeded"): llm.call("Test message 2") # Wait for the rate limit to reset import time time.sleep(time_window + 0.1) # Add a small buffer # After waiting, we should be able to make another request response3 = llm.call("Test message 3") assert response3 == "Rate limited response" assert len(llm.calls) == 2 # First and third requests class RateLimitedLLM(LLM): """Custom LLM implementation with rate limiting. This class demonstrates how to implement a custom LLM with rate limiting capabilities. It uses a sliding window algorithm to ensure that no more than a specified number of requests are made within a given time period. """ def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60): """Initialize the RateLimitedLLM with rate limiting parameters. Args: requests_per_minute: Maximum number of requests allowed per minute. base_response: Default response to return. time_window: Time window in seconds for rate limiting (default: 60). This is configurable for testing purposes. Raises: ValueError: If requests_per_minute is not a positive integer. 
""" super().__init__() if not isinstance(requests_per_minute, int) or requests_per_minute <= 0: raise ValueError("requests_per_minute must be a positive integer") self.requests_per_minute = requests_per_minute self.base_response = base_response self.time_window = time_window self.request_times = deque() self.calls = [] self.stop = [] def _check_rate_limit(self) -> None: """Check if the current request exceeds the rate limit. This method implements a sliding window rate limiting algorithm. It keeps track of request timestamps and ensures that no more than `requests_per_minute` requests are made within the configured time window. Raises: ValueError: If the rate limit is exceeded. """ current_time = time.time() # Remove requests older than the time window while self.request_times and current_time - self.request_times[0] > self.time_window: self.request_times.popleft() # Check if we've exceeded the rate limit if len(self.request_times) >= self.requests_per_minute: wait_time = self.time_window - (current_time - self.request_times[0]) raise ValueError( f"Rate limit exceeded. Maximum {self.requests_per_minute} " f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds." ) # Record this request self.request_times.append(current_time) def call( self, messages: Union[str, List[Dict[str, str]]], tools: Optional[List[dict]] = None, callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, ) -> Union[str, Any]: """Call the LLM with rate limiting. Args: messages: Input messages for the LLM. tools: Optional list of tool schemas for function calling. callbacks: Optional list of callback functions. available_functions: Optional dict mapping function names to callables. Returns: The LLM response. Raises: ValueError: If the rate limit is exceeded. """ # Check rate limit before making the call self._check_rate_limit() self.calls.append({ "messages": messages, "tools": tools, "callbacks": callbacks, "available_functions": available_functions }) return self.base_response def supports_function_calling(self) -> bool: """Return True to indicate that function calling is supported. Returns: True, indicating that this LLM supports function calling. """ return True def supports_stop_words(self) -> bool: """Return True to indicate that stop words are supported. Returns: True, indicating that this LLM supports stop words. """ return True def get_context_window_size(self) -> int: """Return a default context window size. Returns: 8192, a typical context window size for modern LLMs. """ return 8192