diff --git a/src/crewai/agent.py b/src/crewai/agent.py
index f0ee25718..41d514ad6 100644
--- a/src/crewai/agent.py
+++ b/src/crewai/agent.py
@@ -116,9 +116,16 @@ class Agent(BaseAgent):
     def post_init_setup(self):
         self.agent_ops_agent_name = self.role

-        self.llm = create_llm(self.llm)
+        try:
+            self.llm = create_llm(self.llm)
+        except Exception as e:
+            raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}") from e
+
         if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
-            self.function_calling_llm = create_llm(self.function_calling_llm)
+            try:
+                self.function_calling_llm = create_llm(self.function_calling_llm)
+            except Exception as e:
+                raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}") from e

         if not self.agent_executor:
             self._setup_agent_executor()
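
Note on the agent.py change: raising with `from e` keeps the original traceback attached, so callers that catch the RuntimeError can still reach the root cause via `__cause__`. A minimal self-contained sketch of the same wrap-and-chain pattern (the stub names below are illustrative, not crewai's):

    # Stand-in for create_llm() failing on a bad model spec
    def create_llm_stub(spec):
        raise ValueError(f"unknown model spec: {spec!r}")

    def post_init_setup_stub(role, spec):
        try:
            return create_llm_stub(spec)
        except Exception as e:
            # Wrap with agent context, preserving the original exception
            raise RuntimeError(f"Failed to initialize LLM for agent '{role}': {e}") from e

    try:
        post_init_setup_stub("Researcher", object())
    except RuntimeError as err:
        assert isinstance(err.__cause__, ValueError)  # root cause is still reachable
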
""" if cls is LLM and (args or kwargs.get('model') is not None): - from crewai.llm import DefaultLLM + # Import locally to avoid circular imports + # This is safe because DefaultLLM is defined later in this file + DefaultLLM = globals().get('DefaultLLM') + if DefaultLLM is None: + # If DefaultLLM is not yet defined, return a placeholder + # that will be replaced with a real DefaultLLM instance later + return object.__new__(cls) return DefaultLLM(*args, **kwargs) return super().__new__(cls) diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py index 7cef215fa..49e11970b 100644 --- a/tests/custom_llm_test.py +++ b/tests/custom_llm_test.py @@ -1,5 +1,8 @@ +from collections import deque from typing import Any, Dict, List, Optional, Union +import time +import jwt import pytest from crewai.llm import LLM @@ -94,16 +97,65 @@ def test_custom_llm_implementation(): class JWTAuthLLM(LLM): - """Custom LLM implementation with JWT authentication.""" + """Custom LLM implementation with JWT authentication. - def __init__(self, jwt_token: str): + This class demonstrates how to implement a custom LLM that uses JWT + authentication instead of API key-based authentication. It validates + the JWT token before each call and checks for token expiration. + """ + + def __init__(self, jwt_token: str, expiration_buffer: int = 60): + """Initialize the JWTAuthLLM with a JWT token. + + Args: + jwt_token: The JWT token to use for authentication. + expiration_buffer: Buffer time in seconds to warn about token expiration. + Default is 60 seconds. + + Raises: + ValueError: If the JWT token is invalid or missing. + """ super().__init__() if not jwt_token or not isinstance(jwt_token, str): raise ValueError("Invalid JWT token") + self.jwt_token = jwt_token + self.expiration_buffer = expiration_buffer self.calls = [] self.stop = [] + # Validate the token immediately + self._validate_token() + + def _validate_token(self) -> None: + """Validate the JWT token. + + Checks if the token is valid and not expired. Also warns if the token + is about to expire within the expiration_buffer time. + + Raises: + ValueError: If the token is invalid, expired, or malformed. + """ + try: + # Decode without verification to check expiration + # In a real implementation, you would verify the signature + decoded = jwt.decode(self.jwt_token, options={"verify_signature": False}) + + # Check if token is expired or about to expire + if 'exp' in decoded: + expiration_time = decoded['exp'] + current_time = time.time() + + if expiration_time < current_time: + raise ValueError("JWT token has expired") + + if expiration_time < current_time + self.expiration_buffer: + # Token will expire soon, log a warning + import logging + logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds") + except jwt.PyJWTError as e: + raise ValueError(f"Invalid JWT token format: {str(e)}") + def call( self, messages: Union[str, List[Dict[str, str]]], @@ -111,13 +163,34 @@ class JWTAuthLLM(LLM): callbacks: Optional[List[Any]] = None, available_functions: Optional[Dict[str, Any]] = None, ) -> Union[str, Any]: - """Record the call and return a predefined response.""" + """Call the LLM with JWT authentication. + + Validates the JWT token before making the call to ensure it's still valid. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. 
diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index fa64ccd6c..d6a9977ac 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -66,7 +66,13 @@ class LLM(ABC):
             Either a new LLM instance or a DefaultLLM instance for backward compatibility.
         """
         if cls is LLM and (args or kwargs.get('model') is not None):
-            from crewai.llm import DefaultLLM
+            # Look up DefaultLLM at runtime instead of importing it: a module-level
+            # import would be circular, since DefaultLLM is defined later in this file
+            DefaultLLM = globals().get('DefaultLLM')
+            if DefaultLLM is None:
+                # DefaultLLM has not been defined yet (only possible while this
+                # module is still being loaded); fall back to default construction
+                return object.__new__(cls)
             return DefaultLLM(*args, **kwargs)
         return super().__new__(cls)
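
For reviewers unfamiliar with the `__new__` trick above: overriding `__new__` on the base class lets a bare `LLM(model=...)` call transparently construct the concrete default subclass, while subclass construction is untouched. A minimal standalone sketch of the same dispatch (class names here are illustrative, not crewai's):

    from abc import ABC, abstractmethod

    class Base(ABC):
        def __new__(cls, *args, **kwargs):
            # Redirect only direct Base(...) calls; subclasses construct normally
            if cls is Base and (args or kwargs.get("model") is not None):
                return Default(*args, **kwargs)
            return super().__new__(cls)

        @abstractmethod
        def call(self) -> str: ...

    class Default(Base):
        def __init__(self, model: str):
            self.model = model

        def call(self) -> str:
            return f"default:{self.model}"

    assert isinstance(Base(model="gpt-4o"), Default)

One subtlety worth noting: when `__new__` returns an instance of a subclass of the requested class, Python calls the returned instance's `__init__` again with the same arguments, so `DefaultLLM.__init__` should be safe to run twice.
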
diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py
index 7cef215fa..49e11970b 100644
--- a/tests/custom_llm_test.py
+++ b/tests/custom_llm_test.py
@@ -1,5 +1,8 @@
+from collections import deque
 from typing import Any, Dict, List, Optional, Union

+import time
+import jwt
 import pytest

 from crewai.llm import LLM
@@ -94,16 +97,65 @@ def test_custom_llm_implementation():


 class JWTAuthLLM(LLM):
-    """Custom LLM implementation with JWT authentication."""
+    """Custom LLM implementation with JWT authentication.

-    def __init__(self, jwt_token: str):
+    This class demonstrates how to implement a custom LLM that uses JWT
+    authentication instead of API key-based authentication. It validates
+    the JWT token before each call and checks for token expiration.
+    """
+
+    def __init__(self, jwt_token: str, expiration_buffer: int = 60):
+        """Initialize the JWTAuthLLM with a JWT token.
+
+        Args:
+            jwt_token: The JWT token to use for authentication.
+            expiration_buffer: Buffer time in seconds to warn about token expiration.
+                Defaults to 60 seconds.
+
+        Raises:
+            ValueError: If the JWT token is invalid or missing.
+        """
         super().__init__()
         if not jwt_token or not isinstance(jwt_token, str):
             raise ValueError("Invalid JWT token")
+
         self.jwt_token = jwt_token
+        self.expiration_buffer = expiration_buffer
         self.calls = []
         self.stop = []

+        # Validate the token immediately
+        self._validate_token()
+
+    def _validate_token(self) -> None:
+        """Validate the JWT token.
+
+        Checks if the token is valid and not expired. Also warns if the token
+        is about to expire within the expiration_buffer time.
+
+        Raises:
+            ValueError: If the token is invalid, expired, or malformed.
+        """
+        try:
+            # Decode without verifying the signature to check expiration;
+            # a real implementation would verify the signature
+            decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})
+
+            # Check if the token is expired or about to expire
+            if 'exp' in decoded:
+                expiration_time = decoded['exp']
+                current_time = time.time()
+
+                if expiration_time < current_time:
+                    raise ValueError("JWT token has expired")
+
+                if expiration_time < current_time + self.expiration_buffer:
+                    # Token will expire soon, log a warning
+                    import logging
+                    logging.warning(f"JWT token will expire in {expiration_time - current_time:.0f} seconds")
+        except jwt.PyJWTError as e:
+            raise ValueError(f"Invalid JWT token format: {str(e)}") from e
+
     def call(
         self,
         messages: Union[str, List[Dict[str, str]]],
@@ -111,13 +163,34 @@ class JWTAuthLLM(LLM):
         callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> Union[str, Any]:
-        """Record the call and return a predefined response."""
+        """Call the LLM with JWT authentication.
+
+        Validates the JWT token before making the call to ensure it is still valid.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            The LLM response.
+
+        Raises:
+            ValueError: If the JWT token is invalid or expired.
+            TimeoutError: If the request times out.
+            ConnectionError: If there's a connection issue.
+        """
+        # Validate token before making the call
+        self._validate_token()
+
         self.calls.append({
             "messages": messages,
             "tools": tools,
             "callbacks": callbacks,
             "available_functions": available_functions
         })
+
         # In a real implementation, this would use the JWT token to authenticate
         # with an external service
         return "Response from JWT-authenticated LLM"
@@ -137,7 +210,14 @@ class JWTAuthLLM(LLM):

 def test_custom_llm_with_jwt_auth():
     """Test a custom LLM implementation with JWT authentication."""
-    jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
+    # Create a valid JWT token that expires 1 hour from now
+    valid_token = jwt.encode(
+        {"exp": int(time.time()) + 3600},
+        "secret",
+        algorithm="HS256"
+    )
+
+    jwt_llm = JWTAuthLLM(jwt_token=valid_token)

     # Test that create_llm returns the JWT-authenticated LLM instance directly
     result_llm = create_llm(jwt_llm)
@@ -162,6 +242,31 @@ def test_jwt_auth_llm_validation():
     # Test with invalid JWT token (non-string)
     with pytest.raises(ValueError, match="Invalid JWT token"):
         JWTAuthLLM(jwt_token=None)
+
+    # Test with expired token
+    # Create a token that expired 1 hour ago
+    expired_token = jwt.encode(
+        {"exp": int(time.time()) - 3600},
+        "secret",
+        algorithm="HS256"
+    )
+    with pytest.raises(ValueError, match="JWT token has expired"):
+        JWTAuthLLM(jwt_token=expired_token)
+
+    # Test with malformed token
+    with pytest.raises(ValueError, match="Invalid JWT token format"):
+        JWTAuthLLM(jwt_token="not.a.valid.jwt.token")
+
+    # Test with valid token
+    # Create a token that expires 1 hour from now
+    valid_token = jwt.encode(
+        {"exp": int(time.time()) + 3600},
+        "secret",
+        algorithm="HS256"
+    )
+    # This should not raise an exception
+    jwt_llm = JWTAuthLLM(jwt_token=valid_token)
+    assert jwt_llm.jwt_token == valid_token


 class TimeoutHandlingLLM(LLM):
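
The test helper above deliberately decodes with `verify_signature` disabled so it can inspect `exp` without key material. Production code would verify the signature and let PyJWT enforce expiry itself; a hedged sketch of that stricter check (`SECRET_KEY` is a placeholder, not part of this PR):

    import jwt

    SECRET_KEY = "replace-with-real-key-material"

    def validate_token_strict(token: str) -> dict:
        try:
            # Verifies the HS256 signature and rejects expired tokens automatically
            return jwt.decode(token, SECRET_KEY, algorithms=["HS256"])
        except jwt.ExpiredSignatureError as e:
            raise ValueError("JWT token has expired") from e
        except jwt.PyJWTError as e:
            raise ValueError(f"Invalid JWT token: {e}") from e
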
+ """ + # Check rate limit before making the call + self._check_rate_limit() + + self.calls.append({ + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions + }) + + return self.base_response + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported. + + Returns: + True, indicating that this LLM supports function calling. + """ + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported. + + Returns: + True, indicating that this LLM supports stop words. + """ + return True + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 8192, a typical context window size for modern LLMs. + """ + return 8192