mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-27 17:18:13 +00:00
Implement improvements based on PR feedback: enhanced error handling in agent.py, JWT token validation, and rate limiting in custom_llm_test.py
Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
@@ -116,9 +116,16 @@ class Agent(BaseAgent):
|
|||||||
def post_init_setup(self):
|
def post_init_setup(self):
|
||||||
self.agent_ops_agent_name = self.role
|
self.agent_ops_agent_name = self.role
|
||||||
|
|
||||||
self.llm = create_llm(self.llm)
|
try:
|
||||||
|
self.llm = create_llm(self.llm)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to initialize LLM for agent '{self.role}': {str(e)}")
|
||||||
|
|
||||||
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
|
if self.function_calling_llm and not isinstance(self.function_calling_llm, BaseLLM):
|
||||||
self.function_calling_llm = create_llm(self.function_calling_llm)
|
try:
|
||||||
|
self.function_calling_llm = create_llm(self.function_calling_llm)
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to initialize function calling LLM for agent '{self.role}': {str(e)}")
|
||||||
|
|
||||||
if not self.agent_executor:
|
if not self.agent_executor:
|
||||||
self._setup_agent_executor()
|
self._setup_agent_executor()
|
||||||
|
|||||||
@@ -151,14 +151,14 @@ class Crew(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Metrics for the LLM usage during all tasks execution.",
|
description="Metrics for the LLM usage during all tasks execution.",
|
||||||
)
|
)
|
||||||
manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
manager_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will run the agent.", default=None
|
||||||
)
|
)
|
||||||
manager_agent: Optional[BaseAgent] = Field(
|
manager_agent: Optional[BaseAgent] = Field(
|
||||||
description="Custom agent that will be used as manager.", default=None
|
description="Custom agent that will be used as manager.", default=None
|
||||||
)
|
)
|
||||||
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||||
description="Language model that will run the agent.", default=None
|
description="Language model that will be used for function calling.", default=None
|
||||||
)
|
)
|
||||||
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
config: Optional[Union[Json, Dict[str, Any]]] = Field(default=None)
|
||||||
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
id: UUID4 = Field(default_factory=uuid.uuid4, frozen=True)
|
||||||
@@ -197,7 +197,7 @@ class Crew(BaseModel):
|
|||||||
default=False,
|
default=False,
|
||||||
description="Plan the crew execution and add the plan to the crew.",
|
description="Plan the crew execution and add the plan to the crew.",
|
||||||
)
|
)
|
||||||
planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
planning_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="Language model that will run the AgentPlanner if planning is True.",
|
description="Language model that will run the AgentPlanner if planning is True.",
|
||||||
)
|
)
|
||||||
@@ -213,7 +213,7 @@ class Crew(BaseModel):
|
|||||||
default=None,
|
default=None,
|
||||||
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
description="Knowledge sources for the crew. Add knowledge sources to the knowledge object.",
|
||||||
)
|
)
|
||||||
chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field(
|
chat_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="LLM used to handle chatting with the crew.",
|
description="LLM used to handle chatting with the crew.",
|
||||||
)
|
)
|
||||||
|
|||||||
@@ -66,7 +66,13 @@ class LLM(ABC):
|
|||||||
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
|
Either a new LLM instance or a DefaultLLM instance for backward compatibility.
|
||||||
"""
|
"""
|
||||||
if cls is LLM and (args or kwargs.get('model') is not None):
|
if cls is LLM and (args or kwargs.get('model') is not None):
|
||||||
from crewai.llm import DefaultLLM
|
# Import locally to avoid circular imports
|
||||||
|
# This is safe because DefaultLLM is defined later in this file
|
||||||
|
DefaultLLM = globals().get('DefaultLLM')
|
||||||
|
if DefaultLLM is None:
|
||||||
|
# If DefaultLLM is not yet defined, return a placeholder
|
||||||
|
# that will be replaced with a real DefaultLLM instance later
|
||||||
|
return object.__new__(cls)
|
||||||
return DefaultLLM(*args, **kwargs)
|
return DefaultLLM(*args, **kwargs)
|
||||||
return super().__new__(cls)
|
return super().__new__(cls)
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,8 @@
|
|||||||
|
from collections import deque
|
||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
import time
|
||||||
|
|
||||||
|
import jwt
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from crewai.llm import LLM
|
from crewai.llm import LLM
|
||||||
@@ -94,16 +97,65 @@ def test_custom_llm_implementation():
|
|||||||
|
|
||||||
|
|
||||||
class JWTAuthLLM(LLM):
|
class JWTAuthLLM(LLM):
|
||||||
"""Custom LLM implementation with JWT authentication."""
|
"""Custom LLM implementation with JWT authentication.
|
||||||
|
|
||||||
def __init__(self, jwt_token: str):
|
This class demonstrates how to implement a custom LLM that uses JWT
|
||||||
|
authentication instead of API key-based authentication. It validates
|
||||||
|
the JWT token before each call and checks for token expiration.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, jwt_token: str, expiration_buffer: int = 60):
|
||||||
|
"""Initialize the JWTAuthLLM with a JWT token.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
jwt_token: The JWT token to use for authentication.
|
||||||
|
expiration_buffer: Buffer time in seconds to warn about token expiration.
|
||||||
|
Default is 60 seconds.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the JWT token is invalid or missing.
|
||||||
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
if not jwt_token or not isinstance(jwt_token, str):
|
if not jwt_token or not isinstance(jwt_token, str):
|
||||||
raise ValueError("Invalid JWT token")
|
raise ValueError("Invalid JWT token")
|
||||||
|
|
||||||
self.jwt_token = jwt_token
|
self.jwt_token = jwt_token
|
||||||
|
self.expiration_buffer = expiration_buffer
|
||||||
self.calls = []
|
self.calls = []
|
||||||
self.stop = []
|
self.stop = []
|
||||||
|
|
||||||
|
# Validate the token immediately
|
||||||
|
self._validate_token()
|
||||||
|
|
||||||
|
def _validate_token(self) -> None:
|
||||||
|
"""Validate the JWT token.
|
||||||
|
|
||||||
|
Checks if the token is valid and not expired. Also warns if the token
|
||||||
|
is about to expire within the expiration_buffer time.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the token is invalid, expired, or malformed.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
# Decode without verification to check expiration
|
||||||
|
# In a real implementation, you would verify the signature
|
||||||
|
decoded = jwt.decode(self.jwt_token, options={"verify_signature": False})
|
||||||
|
|
||||||
|
# Check if token is expired or about to expire
|
||||||
|
if 'exp' in decoded:
|
||||||
|
expiration_time = decoded['exp']
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
if expiration_time < current_time:
|
||||||
|
raise ValueError("JWT token has expired")
|
||||||
|
|
||||||
|
if expiration_time < current_time + self.expiration_buffer:
|
||||||
|
# Token will expire soon, log a warning
|
||||||
|
import logging
|
||||||
|
logging.warning(f"JWT token will expire in {expiration_time - current_time} seconds")
|
||||||
|
except jwt.PyJWTError as e:
|
||||||
|
raise ValueError(f"Invalid JWT token format: {str(e)}")
|
||||||
|
|
||||||
def call(
|
def call(
|
||||||
self,
|
self,
|
||||||
messages: Union[str, List[Dict[str, str]]],
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
@@ -111,13 +163,34 @@ class JWTAuthLLM(LLM):
|
|||||||
callbacks: Optional[List[Any]] = None,
|
callbacks: Optional[List[Any]] = None,
|
||||||
available_functions: Optional[Dict[str, Any]] = None,
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
) -> Union[str, Any]:
|
) -> Union[str, Any]:
|
||||||
"""Record the call and return a predefined response."""
|
"""Call the LLM with JWT authentication.
|
||||||
|
|
||||||
|
Validates the JWT token before making the call to ensure it's still valid.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Input messages for the LLM.
|
||||||
|
tools: Optional list of tool schemas for function calling.
|
||||||
|
callbacks: Optional list of callback functions.
|
||||||
|
available_functions: Optional dict mapping function names to callables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The LLM response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the JWT token is invalid or expired.
|
||||||
|
TimeoutError: If the request times out.
|
||||||
|
ConnectionError: If there's a connection issue.
|
||||||
|
"""
|
||||||
|
# Validate token before making the call
|
||||||
|
self._validate_token()
|
||||||
|
|
||||||
self.calls.append({
|
self.calls.append({
|
||||||
"messages": messages,
|
"messages": messages,
|
||||||
"tools": tools,
|
"tools": tools,
|
||||||
"callbacks": callbacks,
|
"callbacks": callbacks,
|
||||||
"available_functions": available_functions
|
"available_functions": available_functions
|
||||||
})
|
})
|
||||||
|
|
||||||
# In a real implementation, this would use the JWT token to authenticate
|
# In a real implementation, this would use the JWT token to authenticate
|
||||||
# with an external service
|
# with an external service
|
||||||
return "Response from JWT-authenticated LLM"
|
return "Response from JWT-authenticated LLM"
|
||||||
@@ -137,7 +210,14 @@ class JWTAuthLLM(LLM):
|
|||||||
|
|
||||||
def test_custom_llm_with_jwt_auth():
|
def test_custom_llm_with_jwt_auth():
|
||||||
"""Test a custom LLM implementation with JWT authentication."""
|
"""Test a custom LLM implementation with JWT authentication."""
|
||||||
jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token")
|
# Create a valid JWT token that expires 1 hour from now
|
||||||
|
valid_token = jwt.encode(
|
||||||
|
{"exp": int(time.time()) + 3600},
|
||||||
|
"secret",
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
|
||||||
|
|
||||||
# Test that create_llm returns the JWT-authenticated LLM instance directly
|
# Test that create_llm returns the JWT-authenticated LLM instance directly
|
||||||
result_llm = create_llm(jwt_llm)
|
result_llm = create_llm(jwt_llm)
|
||||||
@@ -163,6 +243,31 @@ def test_jwt_auth_llm_validation():
|
|||||||
with pytest.raises(ValueError, match="Invalid JWT token"):
|
with pytest.raises(ValueError, match="Invalid JWT token"):
|
||||||
JWTAuthLLM(jwt_token=None)
|
JWTAuthLLM(jwt_token=None)
|
||||||
|
|
||||||
|
# Test with expired token
|
||||||
|
# Create a token that expired 1 hour ago
|
||||||
|
expired_token = jwt.encode(
|
||||||
|
{"exp": int(time.time()) - 3600},
|
||||||
|
"secret",
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
with pytest.raises(ValueError, match="JWT token has expired"):
|
||||||
|
JWTAuthLLM(jwt_token=expired_token)
|
||||||
|
|
||||||
|
# Test with malformed token
|
||||||
|
with pytest.raises(ValueError, match="Invalid JWT token format"):
|
||||||
|
JWTAuthLLM(jwt_token="not.a.valid.jwt.token")
|
||||||
|
|
||||||
|
# Test with valid token
|
||||||
|
# Create a token that expires 1 hour from now
|
||||||
|
valid_token = jwt.encode(
|
||||||
|
{"exp": int(time.time()) + 3600},
|
||||||
|
"secret",
|
||||||
|
algorithm="HS256"
|
||||||
|
)
|
||||||
|
# This should not raise an exception
|
||||||
|
jwt_llm = JWTAuthLLM(jwt_token=valid_token)
|
||||||
|
assert jwt_llm.jwt_token == valid_token
|
||||||
|
|
||||||
|
|
||||||
class TimeoutHandlingLLM(LLM):
|
class TimeoutHandlingLLM(LLM):
|
||||||
"""Custom LLM implementation with timeout handling and retry logic."""
|
"""Custom LLM implementation with timeout handling and retry logic."""
|
||||||
@@ -295,3 +400,171 @@ def test_timeout_handling_llm():
|
|||||||
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
|
with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"):
|
||||||
llm.call("Test message")
|
llm.call("Test message")
|
||||||
assert len(llm.calls) == 2 # Initial call + failed retry attempt
|
assert len(llm.calls) == 2 # Initial call + failed retry attempt
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limited_llm():
|
||||||
|
"""Test that rate limiting works correctly."""
|
||||||
|
# Create a rate limited LLM with a very low limit (2 requests per minute)
|
||||||
|
llm = RateLimitedLLM(requests_per_minute=2)
|
||||||
|
|
||||||
|
# First request should succeed
|
||||||
|
response1 = llm.call("Test message 1")
|
||||||
|
assert response1 == "Rate limited response"
|
||||||
|
assert len(llm.calls) == 1
|
||||||
|
|
||||||
|
# Second request should succeed
|
||||||
|
response2 = llm.call("Test message 2")
|
||||||
|
assert response2 == "Rate limited response"
|
||||||
|
assert len(llm.calls) == 2
|
||||||
|
|
||||||
|
# Third request should fail due to rate limiting
|
||||||
|
with pytest.raises(ValueError, match="Rate limit exceeded"):
|
||||||
|
llm.call("Test message 3")
|
||||||
|
|
||||||
|
# Test with invalid requests_per_minute
|
||||||
|
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
|
||||||
|
RateLimitedLLM(requests_per_minute=0)
|
||||||
|
|
||||||
|
with pytest.raises(ValueError, match="requests_per_minute must be a positive integer"):
|
||||||
|
RateLimitedLLM(requests_per_minute=-1)
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_reset():
|
||||||
|
"""Test that rate limits reset after the time window passes."""
|
||||||
|
# Create a rate limited LLM with a very low limit (1 request per minute)
|
||||||
|
# and a short time window for testing (1 second instead of 60 seconds)
|
||||||
|
time_window = 1 # 1 second instead of 60 seconds
|
||||||
|
llm = RateLimitedLLM(requests_per_minute=1, time_window=time_window)
|
||||||
|
|
||||||
|
# First request should succeed
|
||||||
|
response1 = llm.call("Test message 1")
|
||||||
|
assert response1 == "Rate limited response"
|
||||||
|
|
||||||
|
# Second request should fail due to rate limiting
|
||||||
|
with pytest.raises(ValueError, match="Rate limit exceeded"):
|
||||||
|
llm.call("Test message 2")
|
||||||
|
|
||||||
|
# Wait for the rate limit to reset
|
||||||
|
import time
|
||||||
|
time.sleep(time_window + 0.1) # Add a small buffer
|
||||||
|
|
||||||
|
# After waiting, we should be able to make another request
|
||||||
|
response3 = llm.call("Test message 3")
|
||||||
|
assert response3 == "Rate limited response"
|
||||||
|
assert len(llm.calls) == 2 # First and third requests
|
||||||
|
|
||||||
|
|
||||||
|
class RateLimitedLLM(LLM):
|
||||||
|
"""Custom LLM implementation with rate limiting.
|
||||||
|
|
||||||
|
This class demonstrates how to implement a custom LLM with rate limiting
|
||||||
|
capabilities. It uses a sliding window algorithm to ensure that no more
|
||||||
|
than a specified number of requests are made within a given time period.
|
||||||
|
"""
|
||||||
|
|
||||||
|
def __init__(self, requests_per_minute: int = 60, base_response: str = "Rate limited response", time_window: int = 60):
|
||||||
|
"""Initialize the RateLimitedLLM with rate limiting parameters.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
requests_per_minute: Maximum number of requests allowed per minute.
|
||||||
|
base_response: Default response to return.
|
||||||
|
time_window: Time window in seconds for rate limiting (default: 60).
|
||||||
|
This is configurable for testing purposes.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If requests_per_minute is not a positive integer.
|
||||||
|
"""
|
||||||
|
super().__init__()
|
||||||
|
if not isinstance(requests_per_minute, int) or requests_per_minute <= 0:
|
||||||
|
raise ValueError("requests_per_minute must be a positive integer")
|
||||||
|
|
||||||
|
self.requests_per_minute = requests_per_minute
|
||||||
|
self.base_response = base_response
|
||||||
|
self.time_window = time_window
|
||||||
|
self.request_times = deque()
|
||||||
|
self.calls = []
|
||||||
|
self.stop = []
|
||||||
|
|
||||||
|
def _check_rate_limit(self) -> None:
|
||||||
|
"""Check if the current request exceeds the rate limit.
|
||||||
|
|
||||||
|
This method implements a sliding window rate limiting algorithm.
|
||||||
|
It keeps track of request timestamps and ensures that no more than
|
||||||
|
`requests_per_minute` requests are made within the configured time window.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the rate limit is exceeded.
|
||||||
|
"""
|
||||||
|
current_time = time.time()
|
||||||
|
|
||||||
|
# Remove requests older than the time window
|
||||||
|
while self.request_times and current_time - self.request_times[0] > self.time_window:
|
||||||
|
self.request_times.popleft()
|
||||||
|
|
||||||
|
# Check if we've exceeded the rate limit
|
||||||
|
if len(self.request_times) >= self.requests_per_minute:
|
||||||
|
wait_time = self.time_window - (current_time - self.request_times[0])
|
||||||
|
raise ValueError(
|
||||||
|
f"Rate limit exceeded. Maximum {self.requests_per_minute} "
|
||||||
|
f"requests per {self.time_window} seconds. Try again in {wait_time:.2f} seconds."
|
||||||
|
)
|
||||||
|
|
||||||
|
# Record this request
|
||||||
|
self.request_times.append(current_time)
|
||||||
|
|
||||||
|
def call(
|
||||||
|
self,
|
||||||
|
messages: Union[str, List[Dict[str, str]]],
|
||||||
|
tools: Optional[List[dict]] = None,
|
||||||
|
callbacks: Optional[List[Any]] = None,
|
||||||
|
available_functions: Optional[Dict[str, Any]] = None,
|
||||||
|
) -> Union[str, Any]:
|
||||||
|
"""Call the LLM with rate limiting.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
messages: Input messages for the LLM.
|
||||||
|
tools: Optional list of tool schemas for function calling.
|
||||||
|
callbacks: Optional list of callback functions.
|
||||||
|
available_functions: Optional dict mapping function names to callables.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
The LLM response.
|
||||||
|
|
||||||
|
Raises:
|
||||||
|
ValueError: If the rate limit is exceeded.
|
||||||
|
"""
|
||||||
|
# Check rate limit before making the call
|
||||||
|
self._check_rate_limit()
|
||||||
|
|
||||||
|
self.calls.append({
|
||||||
|
"messages": messages,
|
||||||
|
"tools": tools,
|
||||||
|
"callbacks": callbacks,
|
||||||
|
"available_functions": available_functions
|
||||||
|
})
|
||||||
|
|
||||||
|
return self.base_response
|
||||||
|
|
||||||
|
def supports_function_calling(self) -> bool:
|
||||||
|
"""Return True to indicate that function calling is supported.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True, indicating that this LLM supports function calling.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def supports_stop_words(self) -> bool:
|
||||||
|
"""Return True to indicate that stop words are supported.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
True, indicating that this LLM supports stop words.
|
||||||
|
"""
|
||||||
|
return True
|
||||||
|
|
||||||
|
def get_context_window_size(self) -> int:
|
||||||
|
"""Return a default context window size.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
8192, a typical context window size for modern LLMs.
|
||||||
|
"""
|
||||||
|
return 8192
|
||||||
|
|||||||
Reference in New Issue
Block a user