mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-07 07:08:31 +00:00
Compare commits
2 Commits
1.1.0
...
devin/1761
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
6dcc7ac725 | ||
|
|
4371cf5690 |
@@ -66,11 +66,6 @@ openpyxl = [
|
||||
mem0 = ["mem0ai>=0.1.94"]
|
||||
docling = [
|
||||
"docling>=2.12.0",
|
||||
]
|
||||
aisuite = [
|
||||
"aisuite>=0.1.11",
|
||||
|
||||
|
||||
]
|
||||
qdrant = [
|
||||
"qdrant-client[fastembed]>=1.14.3",
|
||||
@@ -137,13 +132,3 @@ build-backend = "hatchling.build"
|
||||
|
||||
[tool.hatch.version]
|
||||
path = "src/crewai/__init__.py"
|
||||
|
||||
# Declare mutually exclusive extras due to conflicting httpx requirements
|
||||
# a2a requires httpx>=0.28.1, while aisuite requires httpx>=0.27.0,<0.28.0
|
||||
# [tool.uv]
|
||||
# conflicts = [
|
||||
# [
|
||||
# { extra = "a2a" },
|
||||
# { extra = "aisuite" },
|
||||
# ],
|
||||
# ]
|
||||
|
||||
@@ -1,99 +0,0 @@
|
||||
"""AI Suite LLM integration for CrewAI.
|
||||
|
||||
This module provides integration with AI Suite for LLM capabilities.
|
||||
"""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import aisuite as ai # type: ignore
|
||||
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
|
||||
|
||||
class AISuiteLLM(BaseLLM):
|
||||
"""AI Suite LLM implementation.
|
||||
|
||||
This class provides integration with AI Suite models through the BaseLLM interface.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: float | None = None,
|
||||
stop: list[str] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the AI Suite LLM.
|
||||
|
||||
Args:
|
||||
model: The model identifier for AI Suite.
|
||||
temperature: Optional temperature setting for response generation.
|
||||
stop: Optional list of stop sequences for generation.
|
||||
**kwargs: Additional keyword arguments passed to the AI Suite client.
|
||||
"""
|
||||
super().__init__(model=model, temperature=temperature, stop=stop)
|
||||
self.client = ai.Client()
|
||||
self.kwargs = kwargs
|
||||
|
||||
def call( # type: ignore[override]
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str | Any:
|
||||
"""Call the AI Suite LLM with the given messages.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
from_task: Optional task caller.
|
||||
from_agent: Optional agent caller.
|
||||
|
||||
Returns:
|
||||
The text response from the LLM.
|
||||
"""
|
||||
completion_params = self._prepare_completion_params(messages, tools)
|
||||
response = self.client.chat.completions.create(**completion_params)
|
||||
|
||||
return response.choices[0].message.content
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: str | list[dict[str, str]],
|
||||
tools: list[dict] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
"""Prepare parameters for the AI Suite completion call.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
tools: Optional list of tool schemas.
|
||||
|
||||
Returns:
|
||||
Dictionary of parameters for the completion API.
|
||||
"""
|
||||
params: dict[str, Any] = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
"temperature": self.temperature,
|
||||
"tools": tools,
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
if self.stop:
|
||||
params["stop"] = self.stop
|
||||
|
||||
return params
|
||||
|
||||
@staticmethod
|
||||
def supports_function_calling() -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
Returns:
|
||||
False, as AI Suite does not currently support function calling.
|
||||
"""
|
||||
return False
|
||||
@@ -419,7 +419,7 @@ def handle_context_length(
|
||||
i18n: I18N instance for messages
|
||||
|
||||
Raises:
|
||||
SystemExit: If context length is exceeded and user opts not to summarize
|
||||
LLMContextLengthExceededError: If context length is exceeded and user opts not to summarize
|
||||
"""
|
||||
if respect_context_window:
|
||||
printer.print(
|
||||
@@ -432,7 +432,7 @@ def handle_context_length(
|
||||
content="Context length exceeded. Consider using smaller text or RAG tools from crewai_tools.",
|
||||
color="red",
|
||||
)
|
||||
raise SystemExit(
|
||||
raise LLMContextLengthExceededError(
|
||||
"Context length exceeded and user opted not to summarize. Consider using smaller text or RAG tools from crewai_tools."
|
||||
)
|
||||
|
||||
|
||||
@@ -18,6 +18,9 @@ from crewai.process import Process
|
||||
from crewai.tools.tool_calling import InstructorToolCalling
|
||||
from crewai.tools.tool_usage import ToolUsage
|
||||
from crewai.utilities.errors import AgentRepositoryError
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
import pytest
|
||||
|
||||
from crewai import Agent, Crew, Task
|
||||
|
||||
145
lib/crewai/tests/utilities/test_agent_utils.py
Normal file
145
lib/crewai/tests/utilities/test_agent_utils.py
Normal file
@@ -0,0 +1,145 @@
|
||||
"""Test agent utility functions."""
|
||||
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.utilities.agent_utils import handle_context_length
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededError,
|
||||
)
|
||||
from crewai.utilities.i18n import I18N
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
|
||||
def test_handle_context_length_raises_exception_when_respect_context_window_false():
|
||||
"""Test that handle_context_length raises LLMContextLengthExceededError when respect_context_window is False."""
|
||||
# Create mocks for dependencies
|
||||
printer = Printer()
|
||||
i18n = I18N()
|
||||
|
||||
# Create an agent just for its LLM
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
respect_context_window=False,
|
||||
)
|
||||
|
||||
llm = agent.llm
|
||||
|
||||
# Create test messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test message that would exceed context length",
|
||||
}
|
||||
]
|
||||
|
||||
# Set up test parameters
|
||||
respect_context_window = False
|
||||
callbacks = []
|
||||
|
||||
with pytest.raises(LLMContextLengthExceededError) as excinfo:
|
||||
handle_context_length(
|
||||
respect_context_window=respect_context_window,
|
||||
printer=printer,
|
||||
messages=messages,
|
||||
llm=llm,
|
||||
callbacks=callbacks,
|
||||
i18n=i18n,
|
||||
)
|
||||
|
||||
assert "Context length exceeded" in str(excinfo.value)
|
||||
assert "user opted not to summarize" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_handle_context_length_summarizes_when_respect_context_window_true():
|
||||
"""Test that handle_context_length calls summarize_messages when respect_context_window is True."""
|
||||
# Create mocks for dependencies
|
||||
printer = Printer()
|
||||
i18n = I18N()
|
||||
|
||||
# Create an agent just for its LLM
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
respect_context_window=True,
|
||||
)
|
||||
|
||||
llm = agent.llm
|
||||
|
||||
# Create test messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test message that would exceed context length",
|
||||
}
|
||||
]
|
||||
|
||||
# Set up test parameters
|
||||
respect_context_window = True
|
||||
callbacks = []
|
||||
|
||||
with patch("crewai.utilities.agent_utils.summarize_messages") as mock_summarize:
|
||||
handle_context_length(
|
||||
respect_context_window=respect_context_window,
|
||||
printer=printer,
|
||||
messages=messages,
|
||||
llm=llm,
|
||||
callbacks=callbacks,
|
||||
i18n=i18n,
|
||||
)
|
||||
|
||||
mock_summarize.assert_called_once_with(
|
||||
messages=messages, llm=llm, callbacks=callbacks, i18n=i18n
|
||||
)
|
||||
|
||||
|
||||
def test_handle_context_length_does_not_raise_system_exit():
|
||||
"""Test that handle_context_length does NOT raise SystemExit (regression test for issue #3774)."""
|
||||
# Create mocks for dependencies
|
||||
printer = Printer()
|
||||
i18n = I18N()
|
||||
|
||||
# Create an agent just for its LLM
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
respect_context_window=False,
|
||||
)
|
||||
|
||||
llm = agent.llm
|
||||
|
||||
# Create test messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "This is a test message that would exceed context length",
|
||||
}
|
||||
]
|
||||
|
||||
# Set up test parameters
|
||||
respect_context_window = False
|
||||
callbacks = []
|
||||
|
||||
with pytest.raises(Exception) as excinfo:
|
||||
handle_context_length(
|
||||
respect_context_window=respect_context_window,
|
||||
printer=printer,
|
||||
messages=messages,
|
||||
llm=llm,
|
||||
callbacks=callbacks,
|
||||
i18n=i18n,
|
||||
)
|
||||
|
||||
assert not isinstance(excinfo.value, SystemExit), (
|
||||
"handle_context_length should not raise SystemExit. "
|
||||
"It should raise LLMContextLengthExceededError instead."
|
||||
)
|
||||
|
||||
assert isinstance(excinfo.value, LLMContextLengthExceededError), (
|
||||
f"Expected LLMContextLengthExceededError but got {type(excinfo.value).__name__}"
|
||||
)
|
||||
Reference in New Issue
Block a user