chore: modernize LLM interface typing and add constants (#3483)

* chore: update LLM interfaces to Python 3.10+ typing

* fix: add missing stop attribute to mock LLM and improve test infrastructure

* fix: correct type ignore comment for aisuite import
Greyson LaLonde
2025-09-10 08:30:49 -04:00
committed by GitHub
parent 6676d94ba1
commit 83682d511f
5 changed files with 126 additions and 58 deletions

View File

@@ -16,3 +16,4 @@ repos:
entry: uv run mypy
language: system
types: [python]
exclude: ^tests/

View File

@@ -133,6 +133,9 @@ select = [
]
ignore = ["E501"] # ignore line too long
[tool.ruff.lint.per-file-ignores]
"tests/**/*.py" = ["S101"] # Allow assert statements in tests
[tool.mypy]
exclude = ["src/crewai/cli/templates", "tests"]

View File

@@ -1,5 +1,14 @@
"""Base LLM abstract class for CrewAI.
This module provides the abstract base class for all LLM implementations
in CrewAI.
"""
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Final
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
class BaseLLM(ABC):
@@ -15,41 +24,38 @@ class BaseLLM(ABC):
messages when things go wrong.
Attributes:
stop (list): A list of stop sequences that the LLM should use to stop generation.
This is used by the CrewAgentExecutor and other components.
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: A list of stop sequences that the LLM should use to stop generation.
"""
model: str
temperature: Optional[float] = None
stop: Optional[List[str]] = None
def __init__(
self,
model: str,
temperature: Optional[float] = None,
):
temperature: float | None = None,
stop: list[str] | None = None,
) -> None:
"""Initialize the BaseLLM with default attributes.
This constructor sets default values for attributes that are expected
by the CrewAgentExecutor and other components.
All custom LLM implementations should call super().__init__() to ensure
that these default attributes are properly initialized.
Args:
model: The model identifier/name.
temperature: Optional temperature setting for response generation.
stop: Optional list of stop sequences for generation.
"""
self.model = model
self.temperature = temperature
self.stop = []
self.stop: list[str] = stop or []
@abstractmethod
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call the LLM with the given messages.
Args:
@@ -64,6 +70,7 @@ class BaseLLM(ABC):
available_functions: Optional dict mapping function names to callables
that can be invoked by the LLM.
from_task: Optional task caller to be used for the LLM call.
from_agent: Optional agent caller to be used for the LLM call.
Returns:
Either a text response from the LLM (str) or
@@ -74,21 +81,20 @@ class BaseLLM(ABC):
TimeoutError: If the LLM request times out.
RuntimeError: If the LLM request fails for other reasons.
"""
pass
def supports_stop_words(self) -> bool:
"""Check if the LLM supports stop words.
Returns:
bool: True if the LLM supports stop words, False otherwise.
True if the LLM supports stop words, False otherwise.
"""
return True # Default implementation assumes support for stop words
return DEFAULT_SUPPORTS_STOP_WORDS
def get_context_window_size(self) -> int:
"""Get the context window size for the LLM.
Returns:
int: The number of tokens/characters the model can handle.
The number of tokens/characters the model can handle.
"""
# Default implementation - subclasses should override with model-specific values
return 4096
return DEFAULT_CONTEXT_WINDOW_SIZE
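
To make the updated contract concrete, here is a minimal sketch of a custom implementation against this interface. The `EchoLLM` class is hypothetical and exists only to illustrate the new `X | None`-style signatures, the `super().__init__()` requirement, and the exported defaults:

```python
from typing import Any

from crewai.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    """Toy LLM that echoes the last user message (illustration only)."""

    def __init__(
        self,
        model: str,
        temperature: float | None = None,
        stop: list[str] | None = None,
    ) -> None:
        # As the docstring above requires, call super().__init__() so that
        # model, temperature, and stop are initialized for the executor.
        super().__init__(model=model, temperature=temperature, stop=stop)

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: list[dict] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        # Accept either a plain string or a chat-style message list.
        if isinstance(messages, str):
            return messages
        return messages[-1]["content"]


llm = EchoLLM(model="echo-1", temperature=0.0, stop=["\nObservation:"])
assert llm.supports_stop_words()              # DEFAULT_SUPPORTS_STOP_WORDS
assert llm.get_context_window_size() == 4096  # DEFAULT_CONTEXT_WINDOW_SIZE
```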

View File

@@ -1,24 +1,62 @@
from typing import Any, Dict, List, Optional, Union
import aisuite as ai
"""AI Suite LLM integration for CrewAI.
This module provides integration with AI Suite for LLM capabilities.
"""
from typing import Any
import aisuite as ai  # type: ignore
from crewai.llms.base_llm import BaseLLM
class AISuiteLLM(BaseLLM):
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
super().__init__(model, temperature, **kwargs)
"""AI Suite LLM implementation.
This class provides integration with AI Suite models through the BaseLLM interface.
"""
def __init__(
self,
model: str,
temperature: float | None = None,
stop: list[str] | None = None,
**kwargs: Any,
) -> None:
"""Initialize the AI Suite LLM.
Args:
model: The model identifier for AI Suite.
temperature: Optional temperature setting for response generation.
stop: Optional list of stop sequences for generation.
**kwargs: Additional keyword arguments passed to the AI Suite client.
"""
super().__init__(model, temperature, stop)
self.client = ai.Client()
self.kwargs = kwargs
def call(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
from_task: Optional[Any] = None,
from_agent: Optional[Any] = None,
) -> Union[str, Any]:
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str | Any:
"""Call the AI Suite LLM with the given messages.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas for function calling.
callbacks: Optional list of callback functions.
available_functions: Optional dict mapping function names to callables.
from_task: Optional task caller.
from_agent: Optional agent caller.
Returns:
The text response from the LLM.
"""
completion_params = self._prepare_completion_params(messages, tools)
response = self.client.chat.completions.create(**completion_params)
@@ -26,15 +64,35 @@ class AISuiteLLM(BaseLLM):
def _prepare_completion_params(
self,
messages: Union[str, List[Dict[str, str]]],
tools: Optional[List[dict]] = None,
) -> Dict[str, Any]:
return {
messages: str | list[dict[str, str]],
tools: list[dict] | None = None,
) -> dict[str, Any]:
"""Prepare parameters for the AI Suite completion call.
Args:
messages: Input messages for the LLM.
tools: Optional list of tool schemas.
Returns:
Dictionary of parameters for the completion API.
"""
params: dict[str, Any] = {
"model": self.model,
"messages": messages,
"temperature": self.temperature,
"tools": tools,
**self.kwargs,
}
if self.stop:
params["stop"] = self.stop
return params
def supports_function_calling(self) -> bool:
"""Check if the LLM supports function calling.
Returns:
False, as AI Suite does not currently support function calling.
"""
return False
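
A usage sketch for the updated class follows; the import path and the `provider:model` identifier are assumptions, and actually running it requires aisuite provider credentials to be configured:

```python
from crewai.llms.third_party.ai_suite import AISuiteLLM  # import path assumed

llm = AISuiteLLM(
    model="openai:gpt-4o",    # aisuite-style "<provider>:<model>" id (assumed)
    temperature=0.2,
    stop=["\nObservation:"],
)

print(llm.call("Say hello in one short sentence."))
```

Note that `_prepare_completion_params` only attaches `stop` when the list is non-empty, so the default `[]` set by `BaseLLM.__init__` never reaches the provider.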

View File

@@ -1,19 +1,18 @@
from collections import defaultdict
from typing import cast
from unittest.mock import Mock
from unittest.mock import Mock, patch
import pytest
from pydantic import BaseModel, Field
from crewai import LLM, Agent
from crewai.flow import Flow, start
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.tools import BaseTool
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
from crewai.flow import Flow, start
from crewai.lite_agent import LiteAgent, LiteAgentOutput
from crewai.llms.base_llm import BaseLLM
from unittest.mock import patch
from crewai.tools import BaseTool
# A simple test tool
@@ -37,10 +36,9 @@ class WebSearchTool(BaseTool):
# This is a mock implementation
if "tokyo" in query.lower():
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
elif "climate change" in query.lower() and "coral" in query.lower():
if "climate change" in query.lower() and "coral" in query.lower():
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
else:
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
# Define Mock Calculator Tool
@@ -53,10 +51,11 @@ class CalculatorTool(BaseTool):
def _run(self, expression: str) -> str:
"""Calculate the result of a mathematical expression."""
try:
result = eval(expression, {"__builtins__": {}})
# Using eval with restricted builtins for test purposes only
result = eval(expression, {"__builtins__": {}}) # noqa: S307
return f"The result of {expression} is {result}"
except Exception as e:
return f"Error calculating {expression}: {str(e)}"
return f"Error calculating {expression}: {e!s}"
# Define a custom response format using Pydantic
@@ -148,12 +147,12 @@ def test_lite_agent_with_tools():
"What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?"
)
assert (
"21 million" in result.raw or "37 million" in result.raw
), "Agent should find Tokyo's population"
assert (
"per square kilometer" in result.raw
), "Agent should calculate population density"
assert "21 million" in result.raw or "37 million" in result.raw, (
"Agent should find Tokyo's population"
)
assert "per square kilometer" in result.raw, (
"Agent should calculate population density"
)
received_events = []
@@ -294,6 +293,7 @@ def test_sets_parent_flow_when_inside_flow():
mock_llm = Mock(spec=LLM)
mock_llm.call.return_value = "Test response"
mock_llm.stop = []
class MyFlow(Flow):
@start()
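
For context on the `mock_llm.stop = []` line added above: `Mock(spec=LLM)` only auto-exposes attributes that are discoverable on the class itself, and with `stop` now initialized inside `__init__` rather than as a class-level default, a spec'd mock may no longer provide it automatically. A minimal sketch of the pattern, assuming a component that reads `llm.stop`:

```python
from unittest.mock import Mock

from crewai import LLM

mock_llm = Mock(spec=LLM)
mock_llm.call.return_value = "Test response"
# Set `stop` explicitly: the spec'd mock may not inherit it from the class
# anymore, and executor code that reads `llm.stop` expects a real list.
mock_llm.stop = []

assert mock_llm.stop == []
```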