From 807c13e144c268ed2164e519d5097a84b52031f0 Mon Sep 17 00:00:00 2001 From: "devin-ai-integration[bot]" <158243242+devin-ai-integration[bot]@users.noreply.github.com> Date: Tue, 25 Mar 2025 12:39:08 -0400 Subject: [PATCH] Add support for custom LLM implementations (#2277) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit * Add support for custom LLM implementations Co-Authored-By: Joe Moura * Fix import sorting and type annotations Co-Authored-By: Joe Moura * Fix linting issues with import sorting Co-Authored-By: Joe Moura * Fix type errors in crew.py by updating tool-related methods to return List[BaseTool] Co-Authored-By: Joe Moura * Enhance custom LLM implementation with better error handling, documentation, and test coverage Co-Authored-By: Joe Moura * Refactor LLM module by extracting BaseLLM to a separate file This commit moves the BaseLLM abstract base class from llm.py to a new file llms/base_llm.py to improve code organization. The changes include: - Creating a new file src/crewai/llms/base_llm.py - Moving the BaseLLM class to the new file - Updating imports in __init__.py and llm.py to reflect the new location - Updating test cases to use the new import path The refactoring maintains the existing functionality while improving the project's module structure. * Add AISuite LLM support and update dependencies - Integrate AISuite as a new third-party LLM option - Update pyproject.toml and uv.lock to include aisuite package - Modify BaseLLM to support more flexible initialization - Remove unnecessary LLM imports across multiple files - Implement AISuiteLLM with basic chat completion functionality * Update AISuiteLLM and LLM utility type handling - Modify AISuiteLLM to support more flexible input types for messages - Update type hints in AISuiteLLM to allow string or list of message dictionaries - Enhance LLM utility function to support broader LLM type annotations - Remove default `self.stop` attribute from BaseLLM initialization * Update LLM imports and type hints across multiple files - Modify imports in crew_chat.py to use LLM instead of BaseLLM - Update type hints in llm_utils.py to use LLM type - Add optional `stop` parameter to BaseLLM initialization - Refactor type handling for LLM creation and usage * Improve stop words handling in CrewAgentExecutor - Add support for handling existing stop words in LLM configuration - Ensure stop words are correctly merged and deduplicated - Update type hints to support both LLM and BaseLLM types * Remove abstract method set_callbacks from BaseLLM class * Enhance CustomLLM and JWTAuthLLM initialization with model parameter - Update CustomLLM to accept a model parameter during initialization - Modify test cases to include the new model argument - Ensure JWTAuthLLM and TimeoutHandlingLLM also utilize the model parameter in their constructors - Update type hints in create_llm function to support both LLM and BaseLLM types * Enhance create_llm function to support BaseLLM type - Update the create_llm function to accept both LLM and BaseLLM instances - Ensure compatibility with existing LLM handling logic * Update type hint for initialize_chat_llm to support BaseLLM - Modify the return type of initialize_chat_llm function to allow for both LLM and BaseLLM instances - Ensure compatibility with recent changes in create_llm function * Refactor AISuiteLLM to include tools parameter in completion methods - Update the _prepare_completion_params method to accept an optional tools parameter - Modify the chat 
completion method to utilize the new tools parameter for enhanced functionality - Clean up print statements for better code clarity * Remove unused tool_calls handling in AISuiteLLM chat completion method for cleaner code. * Refactor Crew class and LLM hierarchy for improved type handling and code clarity - Update Crew class methods to enhance readability with consistent formatting and type hints. - Change LLM class to inherit from BaseLLM for better structure. - Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor. - Adjust BaseLLM to provide default implementations for stop words and context window size methods. - Clean up AISuiteLLM by removing unused methods related to stop words and context window size. * Remove unused `stream` method from `BaseLLM` class to enhance code clarity and maintainability. --------- Co-authored-by: Devin AI <158243242+devin-ai-integration[bot]@users.noreply.github.com> Co-authored-by: Joe Moura Co-authored-by: Lorenze Jay Co-authored-by: João Moura Co-authored-by: Brandon Hancock (bhancock_ai) <109994880+bhancockio@users.noreply.github.com> --- docs/custom_llm.md | 642 ++++++++++++++++++ pyproject.toml | 3 + src/crewai/__init__.py | 2 + src/crewai/agent.py | 10 +- src/crewai/agents/crew_agent_executor.py | 14 +- src/crewai/cli/crew_chat.py | 4 +- src/crewai/crew.py | 118 ++-- src/crewai/llm.py | 3 +- src/crewai/llms/base_llm.py | 91 +++ src/crewai/llms/third_party/ai_suite.py | 38 ++ .../evaluators/crew_evaluator_handler.py | 6 +- src/crewai/utilities/llm_utils.py | 14 +- .../test_custom_llm_implementation.yaml | 107 +++ .../test_custom_llm_within_crew.yaml | 305 +++++++++ tests/custom_llm_test.py | 359 ++++++++++ uv.lock | 16 + 16 files changed, 1671 insertions(+), 61 deletions(-) create mode 100644 docs/custom_llm.md create mode 100644 src/crewai/llms/base_llm.py create mode 100644 src/crewai/llms/third_party/ai_suite.py create mode 100644 tests/cassettes/test_custom_llm_implementation.yaml create mode 100644 tests/cassettes/test_custom_llm_within_crew.yaml create mode 100644 tests/custom_llm_test.py diff --git a/docs/custom_llm.md b/docs/custom_llm.md new file mode 100644 index 000000000..3d0fdc0c4 --- /dev/null +++ b/docs/custom_llm.md @@ -0,0 +1,642 @@ +# Custom LLM Implementations + +CrewAI now supports custom LLM implementations through the `BaseLLM` abstract base class. This allows you to create your own LLM implementations that don't rely on litellm's authentication mechanism. + +## Using Custom LLM Implementations + +To create a custom LLM implementation, you need to: + +1. Inherit from the `BaseLLM` abstract base class +2. 
Implement the required methods: + - `call()`: The main method to call the LLM with messages + - `supports_function_calling()`: Whether the LLM supports function calling + - `supports_stop_words()`: Whether the LLM supports stop words + - `get_context_window_size()`: The context window size of the LLM + +## Example: Basic Custom LLM + +```python +from crewai import BaseLLM +from typing import Any, Dict, List, Optional, Union + +class CustomLLM(BaseLLM): + def __init__(self, api_key: str, endpoint: str, model: str = "custom-model"): + super().__init__(model) # Initialize the base class; BaseLLM requires a model name and sets default attributes + if not api_key or not isinstance(api_key, str): + raise ValueError("Invalid API key: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.api_key = api_key + self.endpoint = endpoint + self.stop = [] # You can customize stop words if needed + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with the given messages. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + Either a text response from the LLM or the result of a tool function call. + + Raises: + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + ValueError: If the response format is invalid. + """ + # Implement your own logic to call the LLM + # For example, using requests: + import requests + + try: + headers = { + "Authorization": f"Bearer {self.api_key}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 # Set a reasonable timeout + ) + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") + + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. + + Returns: + True if the LLM supports function calling, False otherwise. + """ + # Return True if your LLM supports function calling + return True + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + True if the LLM supports stop words, False otherwise. + """ + # Return True if your LLM supports stop words + return True + + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + Returns: + The context window size as an integer. + """ + # Return the context window size of your LLM + return 8192 +``` + +## Error Handling Best Practices + +When implementing custom LLMs, it's important to handle errors properly to ensure robustness and reliability. Here are some best practices: + +### 1.
Implement Try-Except Blocks for API Calls + +Always wrap API calls in try-except blocks to handle different types of errors: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + try: + # API call implementation + response = requests.post( + self.endpoint, + headers=self.headers, + json=self.prepare_payload(messages), + timeout=30 # Set a reasonable timeout + ) + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") +``` + +### 2. Implement Retry Logic for Transient Failures + +For transient failures like network issues or rate limiting, implement retry logic with exponential backoff: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + import time + + max_retries = 3 + retry_delay = 1 # seconds + + for attempt in range(max_retries): + try: + response = requests.post( + self.endpoint, + headers=self.headers, + json=self.prepare_payload(messages), + timeout=30 + ) + response.raise_for_status() + return response.json()["choices"][0]["message"]["content"] + except (requests.Timeout, requests.ConnectionError) as e: + if attempt < max_retries - 1: + time.sleep(retry_delay * (2 ** attempt)) # Exponential backoff + continue + raise TimeoutError(f"LLM request failed after {max_retries} attempts: {str(e)}") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") +``` + +### 3. Validate Input Parameters + +Always validate input parameters to prevent runtime errors: + +```python +def __init__(self, api_key: str, endpoint: str): + super().__init__() + if not api_key or not isinstance(api_key, str): + raise ValueError("Invalid API key: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.api_key = api_key + self.endpoint = endpoint +``` + +### 4. 
Handle Authentication Errors Gracefully + +Provide clear error messages for authentication failures: + +```python +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + try: + response = requests.post(self.endpoint, headers=self.headers, json=data) + if response.status_code == 401: + raise ValueError("Authentication failed: Invalid API key or token") + elif response.status_code == 403: + raise ValueError("Authorization failed: Insufficient permissions") + response.raise_for_status() + # Process response + except Exception as e: + # Handle error + raise +``` + +## Example: JWT-based Authentication + +For services that use JWT-based authentication instead of API keys, you can implement a custom LLM like this: + +```python +from crewai import BaseLLM, Agent, Task +from typing import Any, Dict, List, Optional, Union + +class JWTAuthLLM(BaseLLM): + def __init__(self, jwt_token: str, endpoint: str): + super().__init__() # Initialize the base class to set default attributes + if not jwt_token or not isinstance(jwt_token, str): + raise ValueError("Invalid JWT token: must be a non-empty string") + if not endpoint or not isinstance(endpoint, str): + raise ValueError("Invalid endpoint URL: must be a non-empty string") + self.jwt_token = jwt_token + self.endpoint = endpoint + self.stop = [] # You can customize stop words if needed + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with JWT authentication. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + Either a text response from the LLM or the result of a tool function call. + + Raises: + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + ValueError: If the response format is invalid. + """ + # Implement your own logic to call the LLM with JWT authentication + import requests + + try: + headers = { + "Authorization": f"Bearer {self.jwt_token}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 # Set a reasonable timeout + ) + + if response.status_code == 401: + raise ValueError("Authentication failed: Invalid JWT token") + elif response.status_code == 403: + raise ValueError("Authorization failed: Insufficient permissions") + + response.raise_for_status() # Raise an exception for HTTP errors + return response.json()["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") + + def supports_function_calling(self) -> bool: + """Check if the LLM supports function calling. 
+ + Returns: + True if the LLM supports function calling, False otherwise. + """ + return True + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + True if the LLM supports stop words, False otherwise. + """ + return True + + def get_context_window_size(self) -> int: + """Get the context window size of the LLM. + + Returns: + The context window size as an integer. + """ + return 8192 +``` + +## Troubleshooting + +Here are some common issues you might encounter when implementing custom LLMs and how to resolve them: + +### 1. Authentication Failures + +**Symptoms**: 401 Unauthorized or 403 Forbidden errors + +**Solutions**: +- Verify that your API key or JWT token is valid and not expired +- Check that you're using the correct authentication header format +- Ensure that your token has the necessary permissions + +### 2. Timeout Issues + +**Symptoms**: Requests taking too long or timing out + +**Solutions**: +- Implement timeout handling as shown in the examples +- Use retry logic with exponential backoff +- Consider using a more reliable network connection + +### 3. Response Parsing Errors + +**Symptoms**: KeyError, IndexError, or ValueError when processing responses + +**Solutions**: +- Validate the response format before accessing nested fields +- Implement proper error handling for malformed responses +- Check the API documentation for the expected response format + +### 4. Rate Limiting + +**Symptoms**: 429 Too Many Requests errors + +**Solutions**: +- Implement rate limiting in your custom LLM +- Add exponential backoff for retries +- Consider using a token bucket algorithm for more precise rate control + +## Advanced Features + +### Logging + +Adding logging to your custom LLM can help with debugging and monitoring: + +```python +import logging +from typing import Any, Dict, List, Optional, Union + +class LoggingLLM(BaseLLM): + def __init__(self, api_key: str, endpoint: str): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.logger = logging.getLogger("crewai.llm.custom") + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + self.logger.info(f"Calling LLM with {len(messages) if isinstance(messages, list) else 1} messages") + try: + # API call implementation + response = self._make_api_call(messages, tools) + self.logger.debug(f"LLM response received: {response[:100]}...") + return response + except Exception as e: + self.logger.error(f"LLM call failed: {str(e)}") + raise +``` + +### Rate Limiting + +Implementing rate limiting can help avoid overwhelming the LLM API: + +```python +import time +from typing import Any, Dict, List, Optional, Union + +class RateLimitedLLM(BaseLLM): + def __init__( + self, + api_key: str, + endpoint: str, + requests_per_minute: int = 60 + ): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.requests_per_minute = requests_per_minute + self.request_times: List[float] = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + self._enforce_rate_limit() + # Record this request time + self.request_times.append(time.time()) + # Make the actual API call + return self._make_api_call(messages, tools) + + def 
_enforce_rate_limit(self) -> None: + """Enforce the rate limit by waiting if necessary.""" + now = time.time() + # Remove request times older than 1 minute + self.request_times = [t for t in self.request_times if now - t < 60] + + if len(self.request_times) >= self.requests_per_minute: + # Calculate how long to wait + oldest_request = min(self.request_times) + wait_time = 60 - (now - oldest_request) + if wait_time > 0: + time.sleep(wait_time) +``` + +### Metrics Collection + +Collecting metrics can help you monitor your LLM usage: + +```python +import time +from typing import Any, Dict, List, Optional, Union + +class MetricsCollectingLLM(BaseLLM): + def __init__(self, api_key: str, endpoint: str): + super().__init__() + self.api_key = api_key + self.endpoint = endpoint + self.metrics: Dict[str, Any] = { + "total_calls": 0, + "total_tokens": 0, + "errors": 0, + "latency": [] + } + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + start_time = time.time() + self.metrics["total_calls"] += 1 + + try: + response = self._make_api_call(messages, tools) + # Estimate tokens (simplified) + if isinstance(messages, str): + token_estimate = len(messages) // 4 + else: + token_estimate = sum(len(m.get("content", "")) // 4 for m in messages) + self.metrics["total_tokens"] += token_estimate + return response + except Exception as e: + self.metrics["errors"] += 1 + raise + finally: + latency = time.time() - start_time + self.metrics["latency"].append(latency) + + def get_metrics(self) -> Dict[str, Any]: + """Return the collected metrics.""" + avg_latency = sum(self.metrics["latency"]) / len(self.metrics["latency"]) if self.metrics["latency"] else 0 + return { + **self.metrics, + "avg_latency": avg_latency + } +``` + +## Advanced Usage: Function Calling + +If your LLM supports function calling, you can implement the function calling logic in your custom LLM: + +```python +import json +from typing import Any, Dict, List, Optional, Union + +def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, +) -> Union[str, Any]: + import requests + + try: + headers = { + "Authorization": f"Bearer {self.jwt_token}", + "Content-Type": "application/json" + } + + # Convert string message to proper format if needed + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + data = { + "messages": messages, + "tools": tools + } + + response = requests.post( + self.endpoint, + headers=headers, + json=data, + timeout=30 + ) + response.raise_for_status() + response_data = response.json() + + # Check if the LLM wants to call a function + if response_data["choices"][0]["message"].get("tool_calls"): + tool_calls = response_data["choices"][0]["message"]["tool_calls"] + + # Process each tool call + for tool_call in tool_calls: + function_name = tool_call["function"]["name"] + function_args = json.loads(tool_call["function"]["arguments"]) + + if available_functions and function_name in available_functions: + function_to_call = available_functions[function_name] + function_response = function_to_call(**function_args) + + # Add the function response to the messages + messages.append({ + "role": "tool", + "tool_call_id": tool_call["id"], + "name": function_name, + "content": 
str(function_response) + }) + + # Call the LLM again with the updated messages + return self.call(messages, tools, callbacks, available_functions) + + # Return the text response if no function call + return response_data["choices"][0]["message"]["content"] + except requests.Timeout: + raise TimeoutError("LLM request timed out") + except requests.RequestException as e: + raise RuntimeError(f"LLM request failed: {str(e)}") + except (KeyError, IndexError, ValueError) as e: + raise ValueError(f"Invalid response format: {str(e)}") +``` + +## Using Your Custom LLM with CrewAI + +Once you've implemented your custom LLM, you can use it with CrewAI agents and crews: + +```python +from crewai import Agent, Task, Crew +from typing import Dict, Any + +# Create your custom LLM instance +jwt_llm = JWTAuthLLM( + jwt_token="your.jwt.token", + endpoint="https://your-llm-endpoint.com/v1/chat/completions" +) + +# Use it with an agent +agent = Agent( + role="Research Assistant", + goal="Find information on a topic", + backstory="You are a research assistant tasked with finding information.", + llm=jwt_llm, +) + +# Create a task for the agent +task = Task( + description="Research the benefits of exercise", + agent=agent, + expected_output="A summary of the benefits of exercise", +) + +# Execute the task +result = agent.execute_task(task) +print(result) + +# Or use it with a crew +crew = Crew( + agents=[agent], + tasks=[task], + manager_llm=jwt_llm, # Use your custom LLM for the manager +) + +# Run the crew +result = crew.kickoff() +print(result) +``` + +## Implementing Your Own Authentication Mechanism + +The `BaseLLM` class allows you to implement any authentication mechanism you need, not just JWT or API keys. You can use: + +- OAuth tokens +- Client certificates +- Custom headers +- Session-based authentication +- Any other authentication method required by your LLM provider + +Simply implement the appropriate authentication logic in your custom LLM class. 
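For example, here is a minimal sketch of an OAuth 2.0 client-credentials flow. The token URL, client credentials, completion endpoint, and response field names below are illustrative assumptions; substitute your provider's actual values and response schema:

```python
import time
from typing import Any, Dict, List, Optional, Union

import requests

from crewai import BaseLLM


class OAuthLLM(BaseLLM):
    """Sketch: authenticate with a client-credentials OAuth 2.0 access token."""

    def __init__(self, model: str, client_id: str, client_secret: str,
                 token_url: str, endpoint: str):
        super().__init__(model=model)
        self.client_id = client_id
        self.client_secret = client_secret
        self.token_url = token_url    # assumed OAuth token endpoint
        self.endpoint = endpoint      # assumed chat-completions endpoint
        self._access_token: Optional[str] = None
        self._token_expires_at: float = 0.0

    def _get_access_token(self) -> str:
        # Refresh the token when it is missing or within 60 seconds of expiry.
        if self._access_token is None or time.time() > self._token_expires_at - 60:
            response = requests.post(
                self.token_url,
                data={
                    "grant_type": "client_credentials",
                    "client_id": self.client_id,
                    "client_secret": self.client_secret,
                },
                timeout=30,
            )
            response.raise_for_status()
            payload = response.json()
            self._access_token = payload["access_token"]
            self._token_expires_at = time.time() + payload.get("expires_in", 3600)
        return self._access_token

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> Union[str, Any]:
        # Convert string message to proper format if needed
        if isinstance(messages, str):
            messages = [{"role": "user", "content": messages}]

        headers = {
            "Authorization": f"Bearer {self._get_access_token()}",
            "Content-Type": "application/json",
        }
        response = requests.post(
            self.endpoint,
            headers=headers,
            json={"messages": messages, "tools": tools},
            timeout=30,
        )
        response.raise_for_status()
        # Assumes an OpenAI-style response shape, as in the earlier examples.
        return response.json()["choices"][0]["message"]["content"]

    def supports_function_calling(self) -> bool:
        return False
```

The same pattern applies to session cookies, mutual TLS, or custom headers: acquire or refresh the credential inside a helper, then attach it to each request in `call()`.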
diff --git a/pyproject.toml b/pyproject.toml index 6e895be32..0d7b9068e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,9 @@ mem0 = ["mem0ai>=0.1.29"] docling = [ "docling>=2.12.0", ] +aisuite = [ + "aisuite>=0.1.10", +] [tool.uv] dev-dependencies = [ diff --git a/src/crewai/__init__.py b/src/crewai/__init__.py index 4a992ff88..67d63b82c 100644 --- a/src/crewai/__init__.py +++ b/src/crewai/__init__.py @@ -5,6 +5,7 @@ from crewai.crew import Crew from crewai.flow.flow import Flow from crewai.knowledge.knowledge import Knowledge from crewai.llm import LLM +from crewai.llms.base_llm import BaseLLM from crewai.process import Process from crewai.task import Task @@ -21,6 +22,7 @@ __all__ = [ "Process", "Task", "LLM", + "BaseLLM", "Flow", "Knowledge", ] diff --git a/src/crewai/agent.py b/src/crewai/agent.py index d10b768d4..1680f4e8e 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -11,7 +11,7 @@ from crewai.agents.crew_agent_executor import CrewAgentExecutor from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context -from crewai.llm import LLM +from crewai.llm import BaseLLM from crewai.memory.contextual.contextual_memory import ContextualMemory from crewai.security import Fingerprint from crewai.task import Task @@ -71,10 +71,10 @@ class Agent(BaseAgent): default=True, description="Use system prompt for the agent.", ) - llm: Union[str, InstanceOf[LLM], Any] = Field( + llm: Union[str, InstanceOf[BaseLLM], Any] = Field( description="Language model that will run the agent.", default=None ) - function_calling_llm: Optional[Union[str, InstanceOf[LLM], Any]] = Field( + function_calling_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) system_template: Optional[str] = Field( @@ -118,7 +118,9 @@ class Agent(BaseAgent): self.agent_ops_agent_name = self.role self.llm = create_llm(self.llm) - if self.function_calling_llm and not isinstance(self.function_calling_llm, LLM): + if self.function_calling_llm and not isinstance( + self.function_calling_llm, BaseLLM + ): self.function_calling_llm = create_llm(self.function_calling_llm) if not self.agent_executor: diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 452b343c8..bb17cd095 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -13,7 +13,7 @@ from crewai.agents.parser import ( OutputParserException, ) from crewai.agents.tools_handler import ToolsHandler -from crewai.llm import LLM +from crewai.llm import BaseLLM from crewai.tools.base_tool import BaseTool from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException from crewai.utilities import I18N, Printer @@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): callbacks: List[Any] = [], ): self._i18n: I18N = I18N() - self.llm: LLM = llm + self.llm: BaseLLM = llm self.task = task self.agent = agent self.crew = crew @@ -87,8 +87,14 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): self.tool_name_to_tool_map: Dict[str, BaseTool] = { tool.name: tool for tool in self.tools } - self.stop = stop_words - self.llm.stop = list(set(self.llm.stop + self.stop)) + existing_stop = self.llm.stop or [] + self.llm.stop = list( + set( + existing_stop + self.stop + if isinstance(existing_stop, list) + else self.stop + ) + ) def invoke(self, inputs: 
Dict[str, str]) -> Dict[str, Any]: if "system" in self.prompt: diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index cd0da2bb8..1b4e18c78 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -14,7 +14,7 @@ from packaging import version from crewai.cli.utils import read_toml from crewai.cli.version import get_crewai_version from crewai.crew import Crew -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM from crewai.types.crew_chat import ChatInputField, ChatInputs from crewai.utilities.llm_utils import create_llm @@ -116,7 +116,7 @@ def show_loading(event: threading.Event): print() -def initialize_chat_llm(crew: Crew) -> Optional[LLM]: +def initialize_chat_llm(crew: Crew) -> Optional[LLM | BaseLLM]: """Initializes the chat LLM and handles exceptions.""" try: return create_llm(crew.chat_llm) diff --git a/src/crewai/crew.py b/src/crewai/crew.py index c4216fb61..c82ff309f 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -6,8 +6,9 @@ import warnings from concurrent.futures import Future from copy import copy as shallow_copy from hashlib import md5 -from typing import Any, Callable, Dict, List, Optional, Set, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Set, Tuple, TypeVar, Union, cast +from langchain_core.tools import BaseTool as LangchainBaseTool from pydantic import ( UUID4, BaseModel, @@ -26,7 +27,7 @@ from crewai.agents.cache import CacheHandler from crewai.crews.crew_output import CrewOutput from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM from crewai.memory.entity.entity_memory import EntityMemory from crewai.memory.long_term.long_term_memory import LongTermMemory from crewai.memory.short_term.short_term_memory import ShortTermMemory @@ -37,7 +38,7 @@ from crewai.task import Task from crewai.tasks.conditional_task import ConditionalTask from crewai.tasks.task_output import TaskOutput from crewai.tools.agent_tools.agent_tools import AgentTools -from crewai.tools.base_tool import Tool +from crewai.tools.base_tool import BaseTool, Tool from crewai.types.usage_metrics import UsageMetrics from crewai.utilities import I18N, FileHandler, Logger, RPMController from crewai.utilities.constants import TRAINING_DATA_FILE @@ -153,7 +154,7 @@ class Crew(BaseModel): default=None, description="Metrics for the LLM usage during all tasks execution.", ) - manager_llm: Optional[Any] = Field( + manager_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( description="Language model that will run the agent.", default=None ) manager_agent: Optional[BaseAgent] = Field( @@ -187,7 +188,7 @@ class Crew(BaseModel): default=None, description="Maximum number of requests per minute for the crew execution to be respected.", ) - prompt_file: str = Field( + prompt_file: Optional[str] = Field( default=None, description="Path to the prompt json file to be used for the crew.", ) @@ -199,7 +200,7 @@ class Crew(BaseModel): default=False, description="Plan the crew execution and add the plan to the crew.", ) - planning_llm: Optional[Any] = Field( + planning_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( default=None, description="Language model that will run the AgentPlanner if planning is True.", ) @@ -215,7 +216,7 @@ class Crew(BaseModel): default=None, description="Knowledge sources for the crew. 
Add knowledge sources to the knowledge object.", ) - chat_llm: Optional[Any] = Field( + chat_llm: Optional[Union[str, InstanceOf[BaseLLM], Any]] = Field( default=None, description="LLM used to handle chatting with the crew.", ) @@ -819,7 +820,12 @@ class Crew(BaseModel): # Determine which tools to use - task tools take precedence over agent tools tools_for_task = task.tools or agent_to_use.tools or [] - tools_for_task = self._prepare_tools(agent_to_use, task, tools_for_task) + # Prepare tools and ensure they're compatible with task execution + tools_for_task = self._prepare_tools( + agent_to_use, + task, + cast(Union[List[Tool], List[BaseTool]], tools_for_task), + ) self._log_task_start(task, agent_to_use.role) @@ -838,7 +844,7 @@ class Crew(BaseModel): future = task.execute_async( agent=agent_to_use, context=context, - tools=tools_for_task, + tools=cast(List[BaseTool], tools_for_task), ) futures.append((task, future, task_index)) else: @@ -850,7 +856,7 @@ class Crew(BaseModel): task_output = task.execute_sync( agent=agent_to_use, context=context, - tools=tools_for_task, + tools=cast(List[BaseTool], tools_for_task), ) task_outputs.append(task_output) self._process_task_result(task, task_output) @@ -888,10 +894,12 @@ class Crew(BaseModel): return None def _prepare_tools( - self, agent: BaseAgent, task: Task, tools: List[Tool] - ) -> List[Tool]: + self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: # Add delegation tools if agent allows delegation - if agent.allow_delegation: + if hasattr(agent, "allow_delegation") and getattr( + agent, "allow_delegation", False + ): if self.process == Process.hierarchical: if self.manager_agent: tools = self._update_manager_tools(task, tools) @@ -900,17 +908,24 @@ class Crew(BaseModel): "Manager agent is required for hierarchical process." 
) - elif agent and agent.allow_delegation: + elif agent: tools = self._add_delegation_tools(task, tools) # Add code execution tools if agent allows code execution - if agent.allow_code_execution: + if hasattr(agent, "allow_code_execution") and getattr( + agent, "allow_code_execution", False + ): tools = self._add_code_execution_tools(agent, tools) - if agent and agent.multimodal: + if ( + agent + and hasattr(agent, "multimodal") + and getattr(agent, "multimodal", False) + ): tools = self._add_multimodal_tools(agent, tools) - return tools + # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async + return cast(List[BaseTool], tools) def _get_agent_to_use(self, task: Task) -> Optional[BaseAgent]: if self.process == Process.hierarchical: @@ -918,11 +933,13 @@ class Crew(BaseModel): return task.agent def _merge_tools( - self, existing_tools: List[Tool], new_tools: List[Tool] - ) -> List[Tool]: + self, + existing_tools: Union[List[Tool], List[BaseTool]], + new_tools: Union[List[Tool], List[BaseTool]], + ) -> List[BaseTool]: """Merge new tools into existing tools list, avoiding duplicates by tool name.""" if not new_tools: - return existing_tools + return cast(List[BaseTool], existing_tools) # Create mapping of tool names to new tools new_tool_map = {tool.name: tool for tool in new_tools} @@ -933,23 +950,41 @@ class Crew(BaseModel): # Add all new tools tools.extend(new_tools) - return tools + return cast(List[BaseTool], tools) def _inject_delegation_tools( - self, tools: List[Tool], task_agent: BaseAgent, agents: List[BaseAgent] - ): - delegation_tools = task_agent.get_delegation_tools(agents) - return self._merge_tools(tools, delegation_tools) + self, + tools: Union[List[Tool], List[BaseTool]], + task_agent: BaseAgent, + agents: List[BaseAgent], + ) -> List[BaseTool]: + if hasattr(task_agent, "get_delegation_tools"): + delegation_tools = task_agent.get_delegation_tools(agents) + # Cast delegation_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) + return cast(List[BaseTool], tools) - def _add_multimodal_tools(self, agent: BaseAgent, tools: List[Tool]): - multimodal_tools = agent.get_multimodal_tools() - return self._merge_tools(tools, multimodal_tools) + def _add_multimodal_tools( + self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: + if hasattr(agent, "get_multimodal_tools"): + multimodal_tools = agent.get_multimodal_tools() + # Cast multimodal_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) + return cast(List[BaseTool], tools) - def _add_code_execution_tools(self, agent: BaseAgent, tools: List[Tool]): - code_tools = agent.get_code_execution_tools() - return self._merge_tools(tools, code_tools) + def _add_code_execution_tools( + self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: + if hasattr(agent, "get_code_execution_tools"): + code_tools = agent.get_code_execution_tools() + # Cast code_tools to the expected type for _merge_tools + return self._merge_tools(tools, cast(List[BaseTool], code_tools)) + return cast(List[BaseTool], tools) - def _add_delegation_tools(self, task: Task, tools: List[Tool]): + def _add_delegation_tools( + self, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: agents_for_delegation = [agent for agent in self.agents if agent != task.agent] if len(self.agents) > 1 and len(agents_for_delegation) > 0 
and task.agent: if not tools: @@ -957,7 +992,7 @@ class Crew(BaseModel): tools = self._inject_delegation_tools( tools, task.agent, agents_for_delegation ) - return tools + return cast(List[BaseTool], tools) def _log_task_start(self, task: Task, role: str = "None"): if self.output_log_file: @@ -965,7 +1000,9 @@ class Crew(BaseModel): task_name=task.name, task=task.description, agent=role, status="started" ) - def _update_manager_tools(self, task: Task, tools: List[Tool]): + def _update_manager_tools( + self, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: if self.manager_agent: if task.agent: tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) @@ -973,7 +1010,7 @@ class Crew(BaseModel): tools = self._inject_delegation_tools( tools, self.manager_agent, self.agents ) - return tools + return cast(List[BaseTool], tools) def _get_context(self, task: Task, task_outputs: List[TaskOutput]): context = ( @@ -1214,13 +1251,14 @@ class Crew(BaseModel): def test( self, n_iterations: int, - eval_llm: Union[str, InstanceOf[LLM]], + eval_llm: Union[str, InstanceOf[BaseLLM]], inputs: Optional[Dict[str, Any]] = None, ) -> None: """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" try: - eval_llm = create_llm(eval_llm) - if not eval_llm: + # Create LLM instance and ensure it's of type LLM for CrewEvaluator + llm_instance = create_llm(eval_llm) + if not llm_instance: raise ValueError("Failed to create LLM instance.") crewai_event_bus.emit( @@ -1228,12 +1266,12 @@ class Crew(BaseModel): CrewTestStartedEvent( crew_name=self.name or "crew", n_iterations=n_iterations, - eval_llm=eval_llm, + eval_llm=llm_instance, inputs=inputs, ), ) test_crew = self.copy() - evaluator = CrewEvaluator(test_crew, eval_llm) # type: ignore[arg-type] + evaluator = CrewEvaluator(test_crew, llm_instance) for i in range(1, n_iterations + 1): evaluator.set_iteration(i) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 68ddbacc7..741544662 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -40,6 +40,7 @@ with warnings.catch_warnings(): from litellm.utils import supports_response_schema +from crewai.llms.base_llm import BaseLLM from crewai.utilities.events import crewai_event_bus from crewai.utilities.exceptions.context_window_exceeding_exception import ( LLMContextLengthExceededException, @@ -218,7 +219,7 @@ class StreamingChoices(TypedDict): finish_reason: Optional[str] -class LLM: +class LLM(BaseLLM): def __init__( self, model: str, diff --git a/src/crewai/llms/base_llm.py b/src/crewai/llms/base_llm.py new file mode 100644 index 000000000..c51e8847d --- /dev/null +++ b/src/crewai/llms/base_llm.py @@ -0,0 +1,91 @@ +from abc import ABC, abstractmethod +from typing import Any, Callable, Dict, List, Optional, Union + + +class BaseLLM(ABC): + """Abstract base class for LLM implementations. + + This class defines the interface that all LLM implementations must follow. + Users can extend this class to create custom LLM implementations that don't + rely on litellm's authentication mechanism. + + Custom LLM implementations should handle error cases gracefully, including + timeouts, authentication failures, and malformed responses. They should also + implement proper validation for input parameters and provide clear error + messages when things go wrong. + + Attributes: + stop (list): A list of stop sequences that the LLM should use to stop generation. + This is used by the CrewAgentExecutor and other components. 
+ """ + + model: str + temperature: Optional[float] = None + stop: Optional[List[str]] = None + + def __init__( + self, + model: str, + temperature: Optional[float] = None, + ): + """Initialize the BaseLLM with default attributes. + + This constructor sets default values for attributes that are expected + by the CrewAgentExecutor and other components. + + All custom LLM implementations should call super().__init__() to ensure + that these default attributes are properly initialized. + """ + self.model = model + self.temperature = temperature + self.stop = [] + + @abstractmethod + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Call the LLM with the given messages. + + Args: + messages: Input messages for the LLM. + Can be a string or list of message dictionaries. + If string, it will be converted to a single user message. + If list, each dict must have 'role' and 'content' keys. + tools: Optional list of tool schemas for function calling. + Each tool should define its name, description, and parameters. + callbacks: Optional list of callback functions to be executed + during and after the LLM call. + available_functions: Optional dict mapping function names to callables + that can be invoked by the LLM. + + Returns: + Either a text response from the LLM (str) or + the result of a tool function call (Any). + + Raises: + ValueError: If the messages format is invalid. + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + """ + pass + + def supports_stop_words(self) -> bool: + """Check if the LLM supports stop words. + + Returns: + bool: True if the LLM supports stop words, False otherwise. + """ + return True # Default implementation assumes support for stop words + + def get_context_window_size(self) -> int: + """Get the context window size for the LLM. + + Returns: + int: The number of tokens/characters the model can handle. 
+ """ + # Default implementation - subclasses should override with model-specific values + return 4096 diff --git a/src/crewai/llms/third_party/ai_suite.py b/src/crewai/llms/third_party/ai_suite.py new file mode 100644 index 000000000..78185a081 --- /dev/null +++ b/src/crewai/llms/third_party/ai_suite.py @@ -0,0 +1,38 @@ +from typing import Any, Dict, List, Optional, Union + +import aisuite as ai + +from crewai.llms.base_llm import BaseLLM + + +class AISuiteLLM(BaseLLM): + def __init__(self, model: str, temperature: Optional[float] = None, **kwargs): + super().__init__(model, temperature, **kwargs) + self.client = ai.Client() + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + completion_params = self._prepare_completion_params(messages, tools) + response = self.client.chat.completions.create(**completion_params) + + return response.choices[0].message.content + + def _prepare_completion_params( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + ) -> Dict[str, Any]: + return { + "model": self.model, + "messages": messages, + "temperature": self.temperature, + "tools": tools, + } + + def supports_function_calling(self) -> bool: + return False diff --git a/src/crewai/utilities/evaluators/crew_evaluator_handler.py b/src/crewai/utilities/evaluators/crew_evaluator_handler.py index 9fcd2886d..984dcf97f 100644 --- a/src/crewai/utilities/evaluators/crew_evaluator_handler.py +++ b/src/crewai/utilities/evaluators/crew_evaluator_handler.py @@ -6,7 +6,7 @@ from rich.console import Console from rich.table import Table from crewai.agent import Agent -from crewai.llm import LLM +from crewai.llm import BaseLLM from crewai.task import Task from crewai.tasks.task_output import TaskOutput from crewai.telemetry import Telemetry @@ -24,7 +24,7 @@ class CrewEvaluator: Attributes: crew (Crew): The crew of agents to evaluate. - eval_llm (LLM): Language model instance to use for evaluations + eval_llm (BaseLLM): Language model instance to use for evaluations tasks_scores (defaultdict): A dictionary to store the scores of the agents for each task. iteration (int): The current iteration of the evaluation. """ @@ -33,7 +33,7 @@ class CrewEvaluator: run_execution_times: defaultdict = defaultdict(list) iteration: int = 0 - def __init__(self, crew, eval_llm: InstanceOf[LLM]): + def __init__(self, crew, eval_llm: InstanceOf[BaseLLM]): self.crew = crew self.llm = eval_llm self._telemetry = Telemetry() diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index 5e20cf768..1eb0a4693 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -2,28 +2,28 @@ import os from typing import Any, Dict, List, Optional, Union from crewai.cli.constants import DEFAULT_LLM_MODEL, ENV_VARS, LITELLM_PARAMS -from crewai.llm import LLM +from crewai.llm import LLM, BaseLLM def create_llm( llm_value: Union[str, LLM, Any, None] = None, -) -> Optional[LLM]: +) -> Optional[LLM | BaseLLM]: """ Creates or returns an LLM instance based on the given llm_value. Args: - llm_value (str | LLM | Any | None): + llm_value (str | BaseLLM | Any | None): - str: The model name (e.g., "gpt-4"). - - LLM: Already instantiated LLM, returned as-is. + - BaseLLM: Already instantiated BaseLLM (including LLM), returned as-is. 
- Any: Attempt to extract known attributes like model_name, temperature, etc. - None: Use environment-based or fallback default model. Returns: - An LLM instance if successful, or None if something fails. + A BaseLLM instance if successful, or None if something fails. """ - # 1) If llm_value is already an LLM object, return it directly - if isinstance(llm_value, LLM): + # 1) If llm_value is already a BaseLLM or LLM object, return it directly + if isinstance(llm_value, LLM) or isinstance(llm_value, BaseLLM): return llm_value # 2) If llm_value is a string (model name) diff --git a/tests/cassettes/test_custom_llm_implementation.yaml b/tests/cassettes/test_custom_llm_implementation.yaml new file mode 100644 index 000000000..1ec828eaf --- /dev/null +++ b/tests/cassettes/test_custom_llm_implementation.yaml @@ -0,0 +1,107 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": "What is the answer to life, the universe, and everything?"}], + "model": "gpt-4o-mini", "tools": null}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '206' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.61.0 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.61.0 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.8 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + content: "{\n \"id\": \"chatcmpl-B7W6FS0wpfndLdg12G3H6ZAXcYhJi\",\n \"object\": + \"chat.completion\",\n \"created\": 1741131387,\n \"model\": \"gpt-4o-mini-2024-07-18\",\n + \ \"choices\": [\n {\n \"index\": 0,\n \"message\": {\n \"role\": + \"assistant\",\n \"content\": \"The answer to life, the universe, and + everything, famously found in Douglas Adams' \\\"The Hitchhiker's Guide to the + Galaxy,\\\" is the number 42. 
However, the question itself is left ambiguous, + leading to much speculation and humor in the story.\",\n \"refusal\": + null\n },\n \"logprobs\": null,\n \"finish_reason\": \"stop\"\n + \ }\n ],\n \"usage\": {\n \"prompt_tokens\": 30,\n \"completion_tokens\": + 54,\n \"total_tokens\": 84,\n \"prompt_tokens_details\": {\n \"cached_tokens\": + 0,\n \"audio_tokens\": 0\n },\n \"completion_tokens_details\": {\n + \ \"reasoning_tokens\": 0,\n \"audio_tokens\": 0,\n \"accepted_prediction_tokens\": + 0,\n \"rejected_prediction_tokens\": 0\n }\n },\n \"service_tier\": + \"default\",\n \"system_fingerprint\": \"fp_06737a9306\"\n}\n" + headers: + CF-RAY: + - 91b532234c18cf1f-SJC + Connection: + - keep-alive + Content-Encoding: + - gzip + Content-Type: + - application/json + Date: + - Tue, 04 Mar 2025 23:36:28 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=DgLb6UAE6W4Oeto1Bi2RiKXQVV5TTzkXdXWFdmAEwQQ-1741131388-1.0.1.1-jWQtsT95wOeQbmIxAK7cv8gJWxYi1tQ.IupuJzBDnZr7iEChwVUQBRfnYUBJPDsNly3bakCDArjD_S.FLKwH6xUfvlxgfd4YSBhBPy7bcgw; + path=/; expires=Wed, 05-Mar-25 00:06:28 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=Oa59XCmqjKLKwU34la1hkTunN57JW20E.ZHojvRBfow-1741131388236-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + Transfer-Encoding: + - chunked + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-organization: + - crewai-iuxna1 + openai-processing-ms: + - '776' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999960' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_97824e8fe7c1aca3fbcba7c925388b39 + http_version: HTTP/1.1 + status_code: 200 +version: 1 diff --git a/tests/cassettes/test_custom_llm_within_crew.yaml b/tests/cassettes/test_custom_llm_within_crew.yaml new file mode 100644 index 000000000..9c01ad2f0 --- /dev/null +++ b/tests/cassettes/test_custom_llm_within_crew.yaml @@ -0,0 +1,305 @@ +interactions: +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"role": "system", "content": "You are Say Hi. + You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give + my best complete final answer to the task respond using the exact following + format:\n\nThought: I now can give a great answer\nFinal Answer: Your final + answer must be the great and the most complete as possible, it must be outcome + described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user", + "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria + for your final answer: A greeting to the user\nyou MUST return the actual complete + content as the final answer, not a summary.\n\nBegin! 
This is VERY important + to you, use the tools available and give your best Final Answer, your job depends + on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '931' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.61.0 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.61.0 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.8 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n + \ \"code\": \"missing_required_parameter\"\n }\n}" + headers: + CF-RAY: + - 91b54660799a15b4-SJC + Connection: + - keep-alive + Content-Length: + - '219' + Content-Type: + - application/json + Date: + - Tue, 04 Mar 2025 23:50:16 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=OwS.6cyfDpbxxx8vPp4THv5eNoDMQK0qSVN.wSUyOYk-1741132216-1.0.1.1-QBVd08CjfmDBpNnYQM5ILGbTUWKh6SDM9E4ARG4SV2Z9Q4ltFSFLXoo38OGJApUNZmzn4PtRsyAPsHt_dsrHPF6MD17FPcGtrnAHqCjJrfU; + path=/; expires=Wed, 05-Mar-25 00:20:16 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=n_ebDsAOhJm5Mc7OMx8JDiOaZq5qzHCnVxyS3KN0BwA-1741132216951-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-organization: + - crewai-iuxna1 + openai-processing-ms: + - '19' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999974' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_042a4e8f9432f6fde7a02037bb6caafa + http_version: HTTP/1.1 + status_code: 400 +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"role": "system", "content": "You are Say Hi. + You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give + my best complete final answer to the task respond using the exact following + format:\n\nThought: I now can give a great answer\nFinal Answer: Your final + answer must be the great and the most complete as possible, it must be outcome + described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user", + "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria + for your final answer: A greeting to the user\nyou MUST return the actual complete + content as the final answer, not a summary.\n\nBegin! 
This is VERY important + to you, use the tools available and give your best Final Answer, your job depends + on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '931' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.61.0 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.61.0 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.8 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n + \ \"code\": \"missing_required_parameter\"\n }\n}" + headers: + CF-RAY: + - 91b54664bb1acef1-SJC + Connection: + - keep-alive + Content-Length: + - '219' + Content-Type: + - application/json + Date: + - Tue, 04 Mar 2025 23:50:17 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=.wGU4pJEajaSzFWjp05TBQwWbCNA2CgpYNu7UYOzbbM-1741132217-1.0.1.1-NoLiAx4qkplllldYYxZCOSQGsX6hsPUJIEyqmt84B3g7hjW1s7.jk9C9PYzXagHWjT0sQ9Ny4LZBA94lDJTfDBZpty8NJQha7ZKW0P_msH8; + path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=GAjgJjVLtN49bMeWdWZDYLLkEkK51z5kxK4nKqhAzxY-1741132217161-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-organization: + - crewai-iuxna1 + openai-processing-ms: + - '25' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999974' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_7a1d027da1ef4468e861e570c72e98fb + http_version: HTTP/1.1 + status_code: 400 +- request: + body: '{"messages": [{"role": "system", "content": "You are a helpful assistant."}, + {"role": "user", "content": [{"role": "system", "content": "You are Say Hi. + You just say hi to the user\nYour personal goal is: Say hi to the user\nTo give + my best complete final answer to the task respond using the exact following + format:\n\nThought: I now can give a great answer\nFinal Answer: Your final + answer must be the great and the most complete as possible, it must be outcome + described.\n\nI MUST use these formats, my job depends on it!"}, {"role": "user", + "content": "\nCurrent Task: Say hi to the user\n\nThis is the expected criteria + for your final answer: A greeting to the user\nyou MUST return the actual complete + content as the final answer, not a summary.\n\nBegin! 
This is VERY important + to you, use the tools available and give your best Final Answer, your job depends + on it!\n\nThought:"}]}], "model": "gpt-4o-mini", "tools": null}' + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '931' + content-type: + - application/json + host: + - api.openai.com + user-agent: + - OpenAI/Python 1.61.0 + x-stainless-arch: + - arm64 + x-stainless-async: + - 'false' + x-stainless-lang: + - python + x-stainless-os: + - MacOS + x-stainless-package-version: + - 1.61.0 + x-stainless-retry-count: + - '0' + x-stainless-runtime: + - CPython + x-stainless-runtime-version: + - 3.12.8 + method: POST + uri: https://api.openai.com/v1/chat/completions + response: + content: "{\n \"error\": {\n \"message\": \"Missing required parameter: 'messages[1].content[0].type'.\",\n + \ \"type\": \"invalid_request_error\",\n \"param\": \"messages[1].content[0].type\",\n + \ \"code\": \"missing_required_parameter\"\n }\n}" + headers: + CF-RAY: + - 91b54666183beb22-SJC + Connection: + - keep-alive + Content-Length: + - '219' + Content-Type: + - application/json + Date: + - Tue, 04 Mar 2025 23:50:17 GMT + Server: + - cloudflare + Set-Cookie: + - __cf_bm=VwjWHHpkZMJlosI9RbMqxYDBS1t0JK4tWpAy4lST2QM-1741132217-1.0.1.1-u7PU.ZvVBTXNB5R8vaYfWdPXAjWZ3ZcTAy656VaGDZmKIckk5od._eQdn0W0EGVtEMm3TuF60z4GZAPDwMYvb3_3cw1RuEMmQbp4IIrl7VY; + path=/; expires=Wed, 05-Mar-25 00:20:17 GMT; domain=.api.openai.com; HttpOnly; + Secure; SameSite=None + - _cfuvid=NglAAsQBoiabMuuHFgilRjflSPFqS38VGKnGyweuCuw-1741132217438-0.0.1.1-604800000; + path=/; domain=.api.openai.com; HttpOnly; Secure; SameSite=None + X-Content-Type-Options: + - nosniff + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + cf-cache-status: + - DYNAMIC + openai-organization: + - crewai-iuxna1 + openai-processing-ms: + - '56' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + x-ratelimit-limit-requests: + - '30000' + x-ratelimit-limit-tokens: + - '150000000' + x-ratelimit-remaining-requests: + - '29999' + x-ratelimit-remaining-tokens: + - '149999974' + x-ratelimit-reset-requests: + - 2ms + x-ratelimit-reset-tokens: + - 0s + x-request-id: + - req_3c335b308b82cc2214783a4bf2fc0fd4 + http_version: HTTP/1.1 + status_code: 400 +version: 1 diff --git a/tests/custom_llm_test.py b/tests/custom_llm_test.py new file mode 100644 index 000000000..6bee5b31d --- /dev/null +++ b/tests/custom_llm_test.py @@ -0,0 +1,359 @@ +from typing import Any, Dict, List, Optional, Union +from unittest.mock import Mock + +import pytest + +from crewai import Agent, Crew, Process, Task +from crewai.llms.base_llm import BaseLLM +from crewai.utilities.llm_utils import create_llm + + +class CustomLLM(BaseLLM): + """Custom LLM implementation for testing. + + This is a simple implementation of the BaseLLM abstract base class + that returns a predefined response for testing purposes. + """ + + def __init__(self, response="Default response", model="test-model"): + """Initialize the CustomLLM with a predefined response. + + Args: + response: The predefined response to return from call(). + """ + super().__init__(model=model) + self.response = response + self.call_count = 0 + + def call( + self, + messages, + tools=None, + callbacks=None, + available_functions=None, + ): + """ + Mock LLM call that returns a predefined response. + Properly formats messages to match OpenAI's expected structure. 
+ """ + self.call_count += 1 + + # If input is a string, convert to proper message format + if isinstance(messages, str): + messages = [{"role": "user", "content": messages}] + + # Ensure each message has properly formatted content + for message in messages: + if isinstance(message["content"], str): + message["content"] = [{"type": "text", "text": message["content"]}] + + # Return predefined response in expected format + if "Thought:" in str(messages): + return f"Thought: I will say hi\nFinal Answer: {self.response}" + return self.response + + def supports_function_calling(self) -> bool: + """Return False to indicate that function calling is not supported. + + Returns: + False, indicating that this LLM does not support function calling. + """ + return False + + def supports_stop_words(self) -> bool: + """Return False to indicate that stop words are not supported. + + Returns: + False, indicating that this LLM does not support stop words. + """ + return False + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 4096, a typical context window size for modern LLMs. + """ + return 4096 + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_custom_llm_implementation(): + """Test that a custom LLM implementation works with create_llm.""" + custom_llm = CustomLLM(response="The answer is 42") + + # Test that create_llm returns the custom LLM instance directly + result_llm = create_llm(custom_llm) + + assert result_llm is custom_llm + + # Test calling the custom LLM + response = result_llm.call( + "What is the answer to life, the universe, and everything?" + ) + + # Verify that the response from the custom LLM was used + assert "42" in response + + +@pytest.mark.vcr(filter_headers=["authorization"]) +def test_custom_llm_within_crew(): + """Test that a custom LLM implementation works with create_llm.""" + custom_llm = CustomLLM(response="Hello! Nice to meet you!", model="test-model") + + agent = Agent( + role="Say Hi", + goal="Say hi to the user", + backstory="""You just say hi to the user""", + llm=custom_llm, + ) + + task = Task( + description="Say hi to the user", + expected_output="A greeting to the user", + agent=agent, + ) + + crew = Crew( + agents=[agent], + tasks=[task], + process=Process.sequential, + ) + + result = crew.kickoff() + + # Assert the LLM was called + assert custom_llm.call_count > 0 + # Assert we got a response + assert "Hello!" 
in result.raw + + +def test_custom_llm_message_formatting(): + """Test that the custom LLM properly formats messages""" + custom_llm = CustomLLM(response="Test response", model="test-model") + + # Test with string input + result = custom_llm.call("Test message") + assert result == "Test response" + + # Test with message list + messages = [ + {"role": "system", "content": "System message"}, + {"role": "user", "content": "User message"}, + ] + result = custom_llm.call(messages) + assert result == "Test response" + + +class JWTAuthLLM(BaseLLM): + """Custom LLM implementation with JWT authentication.""" + + def __init__(self, jwt_token: str): + super().__init__(model="test-model") + if not jwt_token or not isinstance(jwt_token, str): + raise ValueError("Invalid JWT token") + self.jwt_token = jwt_token + self.calls = [] + self.stop = [] + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Record the call and return a predefined response.""" + self.calls.append( + { + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + } + ) + # In a real implementation, this would use the JWT token to authenticate + # with an external service + return "Response from JWT-authenticated LLM" + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported.""" + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported.""" + return True + + def get_context_window_size(self) -> int: + """Return a default context window size.""" + return 8192 + + +def test_custom_llm_with_jwt_auth(): + """Test a custom LLM implementation with JWT authentication.""" + jwt_llm = JWTAuthLLM(jwt_token="example.jwt.token") + + # Test that create_llm returns the JWT-authenticated LLM instance directly + result_llm = create_llm(jwt_llm) + + assert result_llm is jwt_llm + + # Test calling the JWT-authenticated LLM + response = result_llm.call("Test message") + + # Verify that the JWT-authenticated LLM was called + assert len(jwt_llm.calls) > 0 + # Verify that the response from the JWT-authenticated LLM was used + assert response == "Response from JWT-authenticated LLM" + + +def test_jwt_auth_llm_validation(): + """Test that JWT token validation works correctly.""" + # Test with invalid JWT token (empty string) + with pytest.raises(ValueError, match="Invalid JWT token"): + JWTAuthLLM(jwt_token="") + + # Test with invalid JWT token (non-string) + with pytest.raises(ValueError, match="Invalid JWT token"): + JWTAuthLLM(jwt_token=None) + + +class TimeoutHandlingLLM(BaseLLM): + """Custom LLM implementation with timeout handling and retry logic.""" + + def __init__(self, max_retries: int = 3, timeout: int = 30): + """Initialize the TimeoutHandlingLLM with retry and timeout settings. + + Args: + max_retries: Maximum number of retry attempts. + timeout: Timeout in seconds for each API call. 
+ """ + super().__init__(model="test-model") + self.max_retries = max_retries + self.timeout = timeout + self.calls = [] + self.stop = [] + self.fail_count = 0 # Number of times to simulate failure + + def call( + self, + messages: Union[str, List[Dict[str, str]]], + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> Union[str, Any]: + """Simulate API calls with timeout handling and retry logic. + + Args: + messages: Input messages for the LLM. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + A response string based on whether this is the first attempt or a retry. + + Raises: + TimeoutError: If all retry attempts fail. + """ + # Record the initial call + self.calls.append( + { + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + "attempt": 0, + } + ) + + # Simulate retry logic + for attempt in range(self.max_retries): + # Skip the first attempt recording since we already did that above + if attempt == 0: + # Simulate a failure if fail_count > 0 + if self.fail_count > 0: + self.fail_count -= 1 + # If we've used all retries, raise an error + if attempt == self.max_retries - 1: + raise TimeoutError( + f"LLM request failed after {self.max_retries} attempts" + ) + # Otherwise, continue to the next attempt (simulating backoff) + continue + else: + # Success on first attempt + return "First attempt response" + else: + # This is a retry attempt (attempt > 0) + # Always record retry attempts + self.calls.append( + { + "retry_attempt": attempt, + "messages": messages, + "tools": tools, + "callbacks": callbacks, + "available_functions": available_functions, + } + ) + + # Simulate a failure if fail_count > 0 + if self.fail_count > 0: + self.fail_count -= 1 + # If we've used all retries, raise an error + if attempt == self.max_retries - 1: + raise TimeoutError( + f"LLM request failed after {self.max_retries} attempts" + ) + # Otherwise, continue to the next attempt (simulating backoff) + continue + else: + # Success on retry + return "Response after retry" + + def supports_function_calling(self) -> bool: + """Return True to indicate that function calling is supported. + + Returns: + True, indicating that this LLM supports function calling. + """ + return True + + def supports_stop_words(self) -> bool: + """Return True to indicate that stop words are supported. + + Returns: + True, indicating that this LLM supports stop words. + """ + return True + + def get_context_window_size(self) -> int: + """Return a default context window size. + + Returns: + 8192, a typical context window size for modern LLMs. 
+ """ + return 8192 + + +def test_timeout_handling_llm(): + """Test a custom LLM implementation with timeout handling and retry logic.""" + # Test successful first attempt + llm = TimeoutHandlingLLM() + response = llm.call("Test message") + assert response == "First attempt response" + assert len(llm.calls) == 1 + + # Test successful retry + llm = TimeoutHandlingLLM() + llm.fail_count = 1 # Fail once, then succeed + response = llm.call("Test message") + assert response == "Response after retry" + assert len(llm.calls) == 2 # Initial call + successful retry call + + # Test failure after all retries + llm = TimeoutHandlingLLM(max_retries=2) + llm.fail_count = 2 # Fail twice, which is all retries + with pytest.raises(TimeoutError, match="LLM request failed after 2 attempts"): + llm.call("Test message") + assert len(llm.calls) == 2 # Initial call + failed retry attempt diff --git a/uv.lock b/uv.lock index d52485124..3a3f30bab 100644 --- a/uv.lock +++ b/uv.lock @@ -139,6 +139,18 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/76/ac/a7305707cb852b7e16ff80eaf5692309bde30e2b1100a1fcacdc8f731d97/aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17", size = 7617 }, ] +[[package]] +name = "aisuite" +version = "0.1.10" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "httpx" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/6a/9d/c7a8a76abb9011dd2bc9a5cb8ffa8231640e20bbdae177ce9ab6cb67c66c/aisuite-0.1.10.tar.gz", hash = "sha256:170e62d4c91fecb22e82a04e058154a111cef473681171e5df7346272e77f414", size = 29052 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/58/c2/9a34a01516de107e5f9406dbfd319b6004340708101d67fa107373da4058/aisuite-0.1.10-py3-none-any.whl", hash = "sha256:c8510ebe38d6546b6a06819171e201fcaf0bf9ae020ffcfe19b6bd90430781ad", size = 43984 }, +] + [[package]] name = "alembic" version = "1.13.3" @@ -651,6 +663,9 @@ dependencies = [ agentops = [ { name = "agentops" }, ] +aisuite = [ + { name = "aisuite" }, +] docling = [ { name = "docling" }, ] @@ -698,6 +713,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "agentops", marker = "extra == 'agentops'", specifier = ">=0.3.0" }, + { name = "aisuite", marker = "extra == 'aisuite'", specifier = ">=0.1.10" }, { name = "appdirs", specifier = ">=1.4.4" }, { name = "auth0-python", specifier = ">=4.7.1" }, { name = "blinker", specifier = ">=1.9.0" },