Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-10 08:38:30 +00:00)
chore: modernize LLM interface typing and add constants (#3483)
* chore: update LLM interfaces to Python 3.10+ typing
* fix: add missing stop attribute to mock LLM and improve test infrastructure
* fix: correct type ignore comment for aisuite import
.pre-commit-config.yaml

@@ -16,3 +16,4 @@ repos:
         entry: uv run mypy
         language: system
         types: [python]
+        exclude: ^tests/
pyproject.toml

@@ -133,6 +133,9 @@ select = [
 ]
 ignore = ["E501"] # ignore line too long
 
+[tool.ruff.lint.per-file-ignores]
+"tests/**/*.py" = ["S101"] # Allow assert statements in tests
+
 [tool.mypy]
 exclude = ["src/crewai/cli/templates", "tests"]
 
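For reference, S101 is ruff's flake8-bandit rule that flags assert statements; the new per-file-ignores entry keeps that check active for library code while allowing pytest's idiomatic bare asserts. A generic sketch of what the exemption permits (hypothetical test file, not from this repo):

```python
# tests/test_example.py (hypothetical): bare asserts are idiomatic in pytest.
# With the new per-file-ignores entry, S101 no longer flags them under tests/**.
def test_addition() -> None:
    assert 1 + 1 == 2
```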
src/crewai/llms/base_llm.py

@@ -1,5 +1,14 @@
+"""Base LLM abstract class for CrewAI.
+
+This module provides the abstract base class for all LLM implementations
+in CrewAI.
+"""
+
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Union
+from typing import Any, Final
 
+DEFAULT_CONTEXT_WINDOW_SIZE: Final[int] = 4096
+DEFAULT_SUPPORTS_STOP_WORDS: Final[bool] = True
+
 
 class BaseLLM(ABC):
@@ -15,41 +24,38 @@ class BaseLLM(ABC):
     messages when things go wrong.
 
     Attributes:
-        stop (list): A list of stop sequences that the LLM should use to stop generation.
-            This is used by the CrewAgentExecutor and other components.
+        model: The model identifier/name.
+        temperature: Optional temperature setting for response generation.
+        stop: A list of stop sequences that the LLM should use to stop generation.
     """
 
-    model: str
-    temperature: Optional[float] = None
-    stop: Optional[List[str]] = None
-
     def __init__(
         self,
         model: str,
-        temperature: Optional[float] = None,
-    ):
+        temperature: float | None = None,
+        stop: list[str] | None = None,
+    ) -> None:
         """Initialize the BaseLLM with default attributes.
 
-        This constructor sets default values for attributes that are expected
-        by the CrewAgentExecutor and other components.
-
-        All custom LLM implementations should call super().__init__() to ensure
-        that these default attributes are properly initialized.
+        Args:
+            model: The model identifier/name.
+            temperature: Optional temperature setting for response generation.
+            stop: Optional list of stop sequences for generation.
+
         """
         self.model = model
         self.temperature = temperature
-        self.stop = []
+        self.stop: list[str] = stop or []
 
     @abstractmethod
     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | Any:
         """Call the LLM with the given messages.
 
         Args:
@@ -64,6 +70,7 @@ class BaseLLM(ABC):
             available_functions: Optional dict mapping function names to callables
                 that can be invoked by the LLM.
             from_task: Optional task caller to be used for the LLM call.
+            from_agent: Optional agent caller to be used for the LLM call.
 
         Returns:
             Either a text response from the LLM (str) or
@@ -74,21 +81,20 @@ class BaseLLM(ABC):
             TimeoutError: If the LLM request times out.
             RuntimeError: If the LLM request fails for other reasons.
         """
-        pass
 
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.
 
         Returns:
-            bool: True if the LLM supports stop words, False otherwise.
+            True if the LLM supports stop words, False otherwise.
         """
-        return True  # Default implementation assumes support for stop words
+        return DEFAULT_SUPPORTS_STOP_WORDS
 
     def get_context_window_size(self) -> int:
         """Get the context window size for the LLM.
 
         Returns:
-            int: The number of tokens/characters the model can handle.
+            The number of tokens/characters the model can handle.
         """
         # Default implementation - subclasses should override with model-specific values
-        return 4096
+        return DEFAULT_CONTEXT_WINDOW_SIZE
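To make the modernized interface concrete, here is a minimal sketch of a custom LLM written against the new BaseLLM signature. The EchoLLM class, its model id, and its trivial echo behavior are hypothetical illustrations; only the BaseLLM API it exercises comes from the diff above:

```python
from typing import Any

from crewai.llms.base_llm import BaseLLM


class EchoLLM(BaseLLM):
    """Hypothetical minimal LLM that echoes the last user message."""

    def __init__(
        self,
        model: str = "echo-1",  # illustrative model id
        temperature: float | None = None,
        stop: list[str] | None = None,
    ) -> None:
        # super().__init__ now accepts stop and normalizes it to a list
        super().__init__(model=model, temperature=temperature, stop=stop)

    def call(
        self,
        messages: str | list[dict[str, str]],
        tools: list[dict] | None = None,
        callbacks: list[Any] | None = None,
        available_functions: dict[str, Any] | None = None,
        from_task: Any | None = None,
        from_agent: Any | None = None,
    ) -> str | Any:
        # Echo a plain string prompt, or the content of the last chat message
        if isinstance(messages, str):
            return messages
        return messages[-1]["content"]


llm = EchoLLM(stop=["\nObservation:"])
assert llm.stop == ["\nObservation:"]         # set by super().__init__
assert llm.supports_stop_words()              # DEFAULT_SUPPORTS_STOP_WORDS
assert llm.get_context_window_size() == 4096  # DEFAULT_CONTEXT_WINDOW_SIZE
```

The inherited defaults are exactly the new module constants, so a subclass only overrides them when it knows better model-specific values.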
src/crewai/llms/third_party/ai_suite.py (vendored, 88 lines changed)
@@ -1,24 +1,62 @@
-from typing import Any, Dict, List, Optional, Union
+"""AI Suite LLM integration for CrewAI.
 
-import aisuite as ai
+This module provides integration with AI Suite for LLM capabilities.
+"""
+
+from typing import Any
+
+import aisuite as ai  # type: ignore
 
 from crewai.llms.base_llm import BaseLLM
 
 
 class AISuiteLLM(BaseLLM):
-    def __init__(self, model: str, temperature: Optional[float] = None, **kwargs):
-        super().__init__(model, temperature, **kwargs)
+    """AI Suite LLM implementation.
+
+    This class provides integration with AI Suite models through the BaseLLM interface.
+    """
+
+    def __init__(
+        self,
+        model: str,
+        temperature: float | None = None,
+        stop: list[str] | None = None,
+        **kwargs: Any,
+    ) -> None:
+        """Initialize the AI Suite LLM.
+
+        Args:
+            model: The model identifier for AI Suite.
+            temperature: Optional temperature setting for response generation.
+            stop: Optional list of stop sequences for generation.
+            **kwargs: Additional keyword arguments passed to the AI Suite client.
+        """
+        super().__init__(model, temperature, stop)
         self.client = ai.Client()
+        self.kwargs = kwargs
 
     def call(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-        callbacks: Optional[List[Any]] = None,
-        available_functions: Optional[Dict[str, Any]] = None,
-        from_task: Optional[Any] = None,
-        from_agent: Optional[Any] = None,
-    ) -> Union[str, Any]:
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+        callbacks: list[Any] | None = None,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str | Any:
+        """Call the AI Suite LLM with the given messages.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+            from_task: Optional task caller.
+            from_agent: Optional agent caller.
+
+        Returns:
+            The text response from the LLM.
+        """
         completion_params = self._prepare_completion_params(messages, tools)
         response = self.client.chat.completions.create(**completion_params)
 
@@ -26,15 +64,35 @@ class AISuiteLLM(BaseLLM):
 
     def _prepare_completion_params(
         self,
-        messages: Union[str, List[Dict[str, str]]],
-        tools: Optional[List[dict]] = None,
-    ) -> Dict[str, Any]:
-        return {
+        messages: str | list[dict[str, str]],
+        tools: list[dict] | None = None,
+    ) -> dict[str, Any]:
+        """Prepare parameters for the AI Suite completion call.
+
+        Args:
+            messages: Input messages for the LLM.
+            tools: Optional list of tool schemas.
+
+        Returns:
+            Dictionary of parameters for the completion API.
+        """
+        params: dict[str, Any] = {
             "model": self.model,
             "messages": messages,
             "temperature": self.temperature,
             "tools": tools,
+            **self.kwargs,
         }
+
+        if self.stop:
+            params["stop"] = self.stop
+
+        return params
 
     def supports_function_calling(self) -> bool:
+        """Check if the LLM supports function calling.
+
+        Returns:
+            False, as AI Suite does not currently support function calling.
+        """
         return False
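Assuming the module path in the file header above, usage of the updated class might look like this sketch; the provider-prefixed model id and the max_tokens kwarg are illustrative assumptions, not taken from the diff:

```python
from crewai.llms.third_party.ai_suite import AISuiteLLM

llm = AISuiteLLM(
    model="openai:gpt-4o",    # illustrative aisuite-style provider:model id
    temperature=0.2,
    stop=["\nObservation:"],  # reaches params["stop"] via the new branch
    max_tokens=512,           # illustrative; forwarded through **self.kwargs
)

print(llm.call([{"role": "user", "content": "Say hello."}]))
```

Note that stop is only included in the request when non-empty, so existing callers that never set it send the same payload as before.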
tests (lite agent test module)

@@ -1,19 +1,18 @@
 from collections import defaultdict
 from typing import cast
-from unittest.mock import Mock
+from unittest.mock import Mock, patch
 
 import pytest
 from pydantic import BaseModel, Field
 
 from crewai import LLM, Agent
-from crewai.flow import Flow, start
-from crewai.lite_agent import LiteAgent, LiteAgentOutput
-from crewai.tools import BaseTool
 from crewai.events.event_bus import crewai_event_bus
 from crewai.events.types.agent_events import LiteAgentExecutionStartedEvent
 from crewai.events.types.tool_usage_events import ToolUsageStartedEvent
+from crewai.flow import Flow, start
+from crewai.lite_agent import LiteAgent, LiteAgentOutput
 from crewai.llms.base_llm import BaseLLM
-from unittest.mock import patch
+from crewai.tools import BaseTool
 
 
 # A simple test tool
@@ -37,10 +36,9 @@ class WebSearchTool(BaseTool):
         # This is a mock implementation
         if "tokyo" in query.lower():
             return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
-        elif "climate change" in query.lower() and "coral" in query.lower():
+        if "climate change" in query.lower() and "coral" in query.lower():
             return "Climate change severely impacts coral reefs through: 1) Ocean warming causing coral bleaching, 2) Ocean acidification reducing calcification, 3) Sea level rise affecting light availability, 4) Increased storm frequency damaging reef structures. Sources: NOAA Coral Reef Conservation Program, Global Coral Reef Alliance."
-        else:
-            return f"Found information about {query}: This is a simulated search result for demonstration purposes."
+        return f"Found information about {query}: This is a simulated search result for demonstration purposes."
 
 
 # Define Mock Calculator Tool
@@ -53,10 +51,11 @@ class CalculatorTool(BaseTool):
     def _run(self, expression: str) -> str:
         """Calculate the result of a mathematical expression."""
         try:
-            result = eval(expression, {"__builtins__": {}})
+            # Using eval with restricted builtins for test purposes only
+            result = eval(expression, {"__builtins__": {}})  # noqa: S307
             return f"The result of {expression} is {result}"
         except Exception as e:
-            return f"Error calculating {expression}: {str(e)}"
+            return f"Error calculating {expression}: {e!s}"
 
 
 # Define a custom response format using Pydantic
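A brief aside on the restricted eval kept here: passing an empty __builtins__ mapping removes access to built-in names, which suffices for this tool's arithmetic, though it is not a true sandbox (hence the explanatory comment and the noqa rather than a rewrite). A self-contained sketch of the behavior:

```python
# Pure arithmetic evaluates normally even with builtins stripped:
print(eval("2 + 3 * 4", {"__builtins__": {}}))  # -> 14

# But built-in names such as __import__ are unavailable:
try:
    eval("__import__('os')", {"__builtins__": {}})
except NameError as exc:
    print(f"blocked: {exc!s}")  # name '__import__' is not defined
```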
@@ -148,12 +147,12 @@ def test_lite_agent_with_tools():
         "What is the population of Tokyo and how many people would that be per square kilometer if Tokyo's area is 2,194 square kilometers?"
     )
 
-    assert (
-        "21 million" in result.raw or "37 million" in result.raw
-    ), "Agent should find Tokyo's population"
-    assert (
-        "per square kilometer" in result.raw
-    ), "Agent should calculate population density"
+    assert "21 million" in result.raw or "37 million" in result.raw, (
+        "Agent should find Tokyo's population"
+    )
+    assert "per square kilometer" in result.raw, (
+        "Agent should calculate population density"
+    )
 
     received_events = []
@@ -294,6 +293,7 @@ def test_sets_parent_flow_when_inside_flow():
 
     mock_llm = Mock(spec=LLM)
     mock_llm.call.return_value = "Test response"
+    mock_llm.stop = []
 
     class MyFlow(Flow):
         @start()
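The added mock_llm.stop = [] line is the "missing stop attribute" fix from the commit message: Mock(spec=LLM) constrains the mock to the LLM class's attribute surface but never runs __init__, so an attribute like stop that is only assigned there is not a usable list until set explicitly. A minimal sketch of the pattern (assertions illustrative):

```python
from unittest.mock import Mock

from crewai import LLM

mock_llm = Mock(spec=LLM)
mock_llm.call.return_value = "Test response"
mock_llm.stop = []  # give consumers a real list, as BaseLLM.__init__ would

assert mock_llm.call("hi") == "Test response"
assert mock_llm.stop == []
```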