mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
chore: modernize LLM interface typing and add constants (#3483)
* chore: update LLM interfaces to Python 3.10+ typing * fix: add missing stop attribute to mock LLM and improve test infrastructure * fix: correct type ignore comment for aisuite import
This commit is contained in:
@@ -16,3 +16,4 @@ repos:
|
||||
entry: uv run mypy
|
||||
language: system
|
||||
types: [python]
|
||||
exclude: ^tests/
|
||||
|
||||
@@ -133,6 +133,9 @@ select = [
|
||||
]
|
||||
ignore = ["E501"] # ignore line too long
|
||||
|
||||
[tool.ruff.lint.per-file-ignores]
|
||||
"tests/**/*.py" = ["S101"] # Allow assert statements in tests
|
||||
|
||||
[tool.mypy]
|
||||
exclude = ["src/crewai/cli/templates", "tests"]
|
||||
|
||||
|
||||
@@ -1,5 +1,14 @@
|
||||
"""Base LLM abstract class for CrewAI.
|
||||
|
||||
This module provides the abstract base class for all LLM implementations
|
||||
in CrewAI.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Final
|
||||
|
||||
DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
|
||||
DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -15,41 +24,38 @@ class BaseLLM(ABC):
|
||||
messages when things go wrong.
|
||||
|
||||
Attributes:
|
||||
stop (list): A list of stop sequences that the LLM should use to stop generation.
|
||||
This is used by the CrewAgentExecutor and other components.
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: A list of stop sequences that the LLM should use to stop generation.
|
||||
"""
|
||||
|
||||
model: str
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: Optional[float] = None,
|
||||
):
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
) -> None:
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
This constructor sets default values for attributes that are expected
|
||||
by the CrewAgentExecutor and other components.
|
||||
|
||||
All custom LLM implementations should call super().__init__() to ensure
|
||||
that these default attributes are properly initialized.
|
||||
Args:
|
||||
model: The model identifier/name.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
"""
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.stop = []
|
||||
self.stop: list[str] = stop or []
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the LLM with the given messages.
|
||||
|
||||
Args:
|
||||
@@ -64,6 +70,7 @@ class BaseLLM(ABC):
|
||||
available_functions: Optional dict mapping function names to callables
|
||||
that can be invoked by the LLM.
|
||||
from_task: Optional task caller to be used for the LLM call.
|
||||
from_agent: Optional agent caller to be used for the LLM call.
|
||||
|
||||
Returns:
|
||||
Either a text response from the LLM (str) or
|
||||
@@ -74,21 +81,20 @@ class BaseLLM(ABC):
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
pass
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
Returns:
|
||||
bool: True if the LLM supports stop words, False otherwise.
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
return True # Default implementation assumes support for stop words
|
||||
return DEFAULT_SUPPORTS_STOP_WORDS
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size for the LLM.
|
||||
|
||||
Returns:
|
||||
int: The number of tokens/characters the model can handle.
|
||||
The number of tokens/characters the model can handle.
|
||||
"""
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return 4096
|
||||
return DEFAULT_CONTEXT_WINDOW_SIZE
|
||||
|
||||
88
src/crewai/llms/third_party/ai_suite.py
vendored
88
src/crewai/llms/third_party/ai_suite.py
vendored
@@ -1,24 +1,62 @@
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
"""AI Suite LLM integration for CrewAI.
|
||||
|
||||
import aisuite as ai
|
||||
This module provides integration with AI Suite for LLM capabilities.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aisuite as ai # type: ignore
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class AISuiteLLM(BaseLLM):
|
||||
def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
|
||||
super().__init__(model, temperature, **kwargs)
|
||||
"""AI Suite LLM implementation.
|
||||
|
||||
This class provides integration with AI Suite models through the BaseLLM interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the AI Suite LLM.
|
||||
|
||||
Args:
|
||||
model: The model identifier for AI Suite.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
**kwargs: Additional keyword arguments passed to the AI Suite client.
|
||||
"""
|
||||
super().__init__(model, temperature, stop)
|
||||
self.client = ai.Client()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
from_task: Optional[Any] = None,
|
||||
from_agent: Optional[Any] = None,
|
||||
) -> Union[str, Any]:
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the AI Suite LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task caller.
|
||||
from_agent: Optional agent caller.
|
||||
|
||||
Returns:
|
||||
The text response from the LLM.
|
||||
"""
|
||||
completion_params = self._prepare_completion_params(messages, tools)
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
|
||||
@@ -26,15 +64,35 @@ class AISuiteLLM(BaseLLM):
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for the AI Suite completion call.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for the completion API.
|
||||
"""
|
||||
params: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": self.temperature,
|
||||
"tools": tools,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
if self.stop:
|
||||
params["stop"] = self.stop
|
||||
|
||||
return params
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
False, as AI Suite does not currently support function calling.
|
||||
"""
|
||||
return False
|
||||
|
||||
@@ -1,19 +1,18 @@
|
||||
from collections import defaultdict
|
||||
from typing import cast
|
||||
from unittest.mock import Mock
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from crewai import LLM, Agent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.tools import BaseTool
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
|
||||
from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
|
||||
from crewai.flow import Flow, start
|
||||
from crewai.lite_agent import LiteAgent, LiteAgentOutput
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from unittest.mock import patch
|
||||
from crewai.tools import BaseTool
|
||||
|
||||
|
||||
# A simple test tool
|
||||
@@ -37,10 +36,9 @@ class WebSearchTool(BaseTool):
|
||||
# This is a mock implementation
|
||||
if "tokyo" in query.lower():
|
||||
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
|
||||
elif "climate change" in query.lower() and "coral" in query.lower():
|
||||
if "climate change" in query.lower() and "coral" in query.lower():
|
||||
return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
|
||||
else:
|
||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||
|
||||
|
||||
# Define Mock Calculator Tool
|
||||
@@ -53,10 +51,11 @@ class CalculatorTool(BaseTool):
|
||||
def _run(self, expression: str) -> str:
|
||||
"""Calculate the result of a mathematical expression."""
|
||||
try:
|
||||
result = eval(expression, {"__builtins__": {}})
|
||||
# Using eval with restricted builtins for test purposes only
|
||||
result = eval(expression, {"__builtins__": {}}) # noqa: S307
|
||||
return f"The result of {expression} is {result}"
|
||||
except Exception as e:
|
||||
return f"Error calculating {expression}: {str(e)}"
|
||||
return f"Error calculating {expression}: {e!s}"
|
||||
|
||||
|
||||
# Define a custom response format using Pydantic
|
||||
@@ -148,12 +147,12 @@ def test_lite_agent_with_tools():
|
||||
"What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?"
|
||||
)
|
||||
|
||||
assert (
|
||||
"21 million" in result.raw or "37 million" in result.raw
|
||||
), "Agent should find Tokyo's population"
|
||||
assert (
|
||||
"per square kilometer" in result.raw
|
||||
), "Agent should calculate population density"
|
||||
assert "21 million" in result.raw or "37 million" in result.raw, (
|
||||
"Agent should find Tokyo's population"
|
||||
)
|
||||
assert "per square kilometer" in result.raw, (
|
||||
"Agent should calculate population density"
|
||||
)
|
||||
|
||||
received_events = []
|
||||
|
||||
@@ -294,6 +293,7 @@ def test_sets_parent_flow_when_inside_flow():
|
||||
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Test response"
|
||||
mock_llm.stop = []
|
||||
|
||||
class MyFlow(Flow):
|
||||
@start()
|
||||
|
||||
Reference in New Issue
Block a user