Refactor Crew class and LLM hierarchy for improved type handling and code clarity

- Update Crew class methods to enhance readability with consistent formatting and type hints.
- Change LLM class to inherit from BaseLLM for better structure.
- Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor.
- Adjust BaseLLM to provide default implementations of the stop-word and context-window-size methods (see the sketch after this list).
- Clean up AISuiteLLM by removing its unused stop-word, context-window, and callback override methods.
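
With these defaults in place, a custom provider only has to implement call() and supports_function_calling(); stop-word support and the context window size fall back to BaseLLM. A minimal sketch of what a third-party subclass might look like after this change (EchoLLM and its trivial call() body are illustrative, not part of the commit):

from typing import Any, Dict, List, Optional, Union

from crewai.llm import BaseLLM


class EchoLLM(BaseLLM):
    """Hypothetical minimal subclass: only the two abstract methods remain."""

    def __init__(self, model: str = "echo-1"):
        super().__init__(model=model)

    def call(
        self,
        messages: Union[str, List[Dict[str, str]]],
        tools: Optional[List[dict]] = None,
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> str:
        # Toy implementation: echo the last user message back.
        if isinstance(messages, str):
            return messages
        return messages[-1]["content"]

    def supports_function_calling(self) -> bool:
        return False


llm = EchoLLM()
print(llm.supports_stop_words())      # True, inherited BaseLLM default
print(llm.get_context_window_size())  # 4096, inherited BaseLLM default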
Author: Lorenze Jay
Date:   2025-03-24 13:56:23 -07:00
parent 7cae76a631
commit e659c352df

5 changed files with 82 additions and 61 deletions

File 1 of 5 (CrewAgentExecutor):

@@ -13,7 +13,7 @@ from crewai.agents.parser import (
     OutputParserException,
 )
 from crewai.agents.tools_handler import ToolsHandler
-from crewai.llm import LLM, BaseLLM
+from crewai.llm import BaseLLM
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
 from crewai.utilities import I18N, Printer
@@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
         callbacks: List[Any] = [],
     ):
         self._i18n: I18N = I18N()
-        self.llm: Union[LLM, BaseLLM] = llm
+        self.llm: BaseLLM = llm
         self.task = task
         self.agent = agent
         self.crew = crew

File 2 of 5 (Crew):

@@ -805,7 +805,11 @@ class Crew(BaseModel):
         # Determine which tools to use - task tools take precedence over agent tools
         tools_for_task = task.tools or agent_to_use.tools or []

         # Prepare tools and ensure they're compatible with task execution
-        tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))
+        tools_for_task = self._prepare_tools(
+            agent_to_use,
+            task,
+            cast(Union[List[Tool], List[BaseTool]], tools_for_task),
+        )

         self._log_task_start(task, agent_to_use.role)
@@ -877,7 +881,9 @@ class Crew(BaseModel):
         self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
     ) -> List[BaseTool]:
         # Add delegation tools if agent allows delegation
-        if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
+        if hasattr(agent, "allow_delegation") and getattr(
+            agent, "allow_delegation", False
+        ):
             if self.process == Process.hierarchical:
                 if self.manager_agent:
                     tools = self._update_manager_tools(task, tools)
@@ -890,10 +896,16 @@ class Crew(BaseModel):
                 tools = self._add_delegation_tools(task, tools)

         # Add code execution tools if agent allows code execution
-        if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
+        if hasattr(agent, "allow_code_execution") and getattr(
+            agent, "allow_code_execution", False
+        ):
             tools = self._add_code_execution_tools(agent, tools)

-        if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
+        if (
+            agent
+            and hasattr(agent, "multimodal")
+            and getattr(agent, "multimodal", False)
+        ):
             tools = self._add_multimodal_tools(agent, tools)

         # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
@@ -905,7 +917,9 @@ class Crew(BaseModel):
         return task.agent

     def _merge_tools(
-        self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
+        self,
+        existing_tools: Union[List[Tool], List[BaseTool]],
+        new_tools: Union[List[Tool], List[BaseTool]],
     ) -> List[BaseTool]:
         """Merge new tools into existing tools list, avoiding duplicates by tool name."""
         if not new_tools:
@@ -923,7 +937,10 @@ class Crew(BaseModel):
         return cast(List[BaseTool], tools)

     def _inject_delegation_tools(
-        self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
+        self,
+        tools: Union[List[Tool], List[BaseTool]],
+        task_agent: BaseAgent,
+        agents: List[BaseAgent],
     ) -> List[BaseTool]:
         if hasattr(task_agent, "get_delegation_tools"):
             delegation_tools = task_agent.get_delegation_tools(agents)
@@ -931,21 +948,27 @@ class Crew(BaseModel):
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
return cast(List[BaseTool], tools)
def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
def _add_multimodal_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
if hasattr(agent, "get_multimodal_tools"):
multimodal_tools = agent.get_multimodal_tools()
# Cast multimodal_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
return cast(List[BaseTool], tools)
def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
def _add_code_execution_tools(
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
if hasattr(agent, "get_code_execution_tools"):
code_tools = agent.get_code_execution_tools()
# Cast code_tools to the expected type for _merge_tools
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
return cast(List[BaseTool], tools)
def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
def _add_delegation_tools(
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
) -> List[BaseTool]:
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
if not tools:
@@ -961,7 +984,9 @@ class Crew(BaseModel):
             task_name=task.name, task=task.description, agent=role, status="started"
         )

-    def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
+    def _update_manager_tools(
+        self, task: Task, tools: Union[List[Tool], List[BaseTool]]
+    ) -> List[BaseTool]:
         if self.manager_agent:
             if task.agent:
                 tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
@@ -1210,7 +1235,7 @@ class Crew(BaseModel):
     def test(
         self,
         n_iterations: int,
-        eval_llm: Union[str, InstanceOf[LLM]],
+        eval_llm: Union[str, InstanceOf[BaseLLM]],
        inputs: Optional[Dict[str, Any]] = None,
     ) -> None:
         """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
@@ -1219,11 +1244,6 @@ class Crew(BaseModel):
         llm_instance = create_llm(eval_llm)
         if not llm_instance:
             raise ValueError("Failed to create LLM instance.")

-        # Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
-        from crewai.llm import LLM
-
-        if not isinstance(llm_instance, LLM):
-            raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")
         crewai_event_bus.emit(
             self,
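
Because eval_llm is now typed as Union[str, InstanceOf[BaseLLM]] and the isinstance guard is gone, any BaseLLM implementation can drive evaluation. A usage sketch, assuming a crew object and the hypothetical EchoLLM from the sketch above:

# Any BaseLLM subclass is now accepted by Crew.test().
crew.test(n_iterations=2, eval_llm=EchoLLM(model="echo-1"))

# Passing a model string still works; create_llm() resolves it internally.
crew.test(n_iterations=2, eval_llm="gpt-4o", inputs={"topic": "type safety"})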

File 3 of 5 (LLM):

@@ -165,7 +165,7 @@ class StreamingChoices(TypedDict):
     finish_reason: Optional[str]


-class LLM:
+class LLM(BaseLLM):
     def __init__(
         self,
         model: str,
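
With LLM subclassing BaseLLM, call sites no longer need Union[LLM, BaseLLM]: a single BaseLLM annotation or isinstance check covers both the built-in class and custom providers. A small sketch (the describe helper is illustrative):

from crewai.llm import LLM, BaseLLM

llm = LLM(model="gpt-4o")
assert isinstance(llm, BaseLLM)  # now true, since LLM inherits from BaseLLM

def describe(llm: BaseLLM) -> str:
    # One annotation now covers LLM and any custom BaseLLM subclass.
    return f"{llm.model}: window={llm.get_context_window_size()}"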

File 4 of 5 (BaseLLM):

@@ -1,5 +1,5 @@
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Optional, Union
from typing import Any, Callable, Dict, List, Optional, Union
class BaseLLM(ABC):
@@ -21,13 +21,12 @@ class BaseLLM(ABC):
     model: str
     temperature: Optional[float] = None
-    stop: Optional[Union[str, List[str]]] = None
+    stop: Optional[List[str]] = None

     def __init__(
         self,
         model: str,
         temperature: Optional[float] = None,
-        stop: Optional[Union[str, List[str]]] = None,
     ):
         """Initialize the BaseLLM with default attributes.
@@ -39,7 +38,7 @@ class BaseLLM(ABC):
         """
         self.model = model
         self.temperature = temperature
-        self.stop = stop
+        self.stop = []

     @abstractmethod
     def call(
@@ -74,42 +73,53 @@
         """
         pass

     @abstractmethod
     def supports_function_calling(self) -> bool:
         """Check if the LLM supports function calling.

         This method should return True if the LLM implementation supports
         function calling (tools), and False otherwise. If this method returns
         True, the LLM should be able to handle the 'tools' parameter in the
         call() method.

         Returns:
             True if the LLM supports function calling, False otherwise.
         """
         pass

-    @abstractmethod
     def supports_stop_words(self) -> bool:
         """Check if the LLM supports stop words.

-        This method should return True if the LLM implementation supports
-        stop words, and False otherwise. If this method returns True, the
-        LLM should respect the 'stop' attribute when generating responses.
-
         Returns:
-            True if the LLM supports stop words, False otherwise.
+            bool: True if the LLM supports stop words, False otherwise.
         """
-        pass
+        return True  # Default implementation assumes support for stop words

-    @abstractmethod
     def get_context_window_size(self) -> int:
-        """Get the context window size of the LLM.
-
-        This method should return the maximum number of tokens that the LLM
-        can process in a single request. This is used by CrewAI to ensure
-        that messages don't exceed the LLM's context window.
+        """Get the context window size for the LLM.

         Returns:
-            The context window size as an integer.
+            int: The number of tokens/characters the model can handle.
         """
-        pass
+        # Default implementation - subclasses should override with model-specific values
+        return 4096
+
+    def stream(
+        self,
+        messages: Union[str, List[Dict[str, str]]],
+        stream_callback: Optional[Callable[[str], None]] = None,
+        tools: Optional[List[dict]] = None,
+        callbacks: Optional[List[Any]] = None,
+        available_functions: Optional[Dict[str, Any]] = None,
+    ) -> str:
+        """Stream responses from the LLM with optional callbacks for each chunk.
+
+        Args:
+            messages: Input messages for the LLM.
+                Can be a string or list of message dictionaries.
+            stream_callback: Optional callback function that receives each
+                text chunk as it arrives.
+            tools: Optional list of tool schemas for function calling.
+            callbacks: Optional list of callback functions.
+            available_functions: Optional dict mapping function names to callables.
+
+        Returns:
+            The complete response as a string (after streaming is complete).
+
+        Raises:
+            ValueError: If the messages format is invalid.
+            TimeoutError: If the LLM request times out.
+            RuntimeError: If the LLM request fails for other reasons.
+        """
+        # Default implementation that doesn't actually stream but calls the callback
+        # Subclasses should override this with proper streaming implementations
+        response = self.call(messages, tools, callbacks, available_functions)
+        if stream_callback:
+            stream_callback(response)
+        return response
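
Since the default stream() just delegates to call() and then fires the callback once with the full response, every subclass gets a working, if non-incremental, streaming API for free. A usage sketch, reusing the hypothetical EchoLLM from earlier:

chunks: list[str] = []

# With the default implementation the callback fires exactly once with the
# complete response; real streaming subclasses would invoke it per chunk.
result = EchoLLM().stream(
    [{"role": "user", "content": "hello"}],
    stream_callback=chunks.append,
)
assert result == "hello"
assert chunks == ["hello"]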

File 5 of 5 (AISuiteLLM):

@@ -36,12 +36,3 @@ class AISuiteLLM(BaseLLM):

     def supports_function_calling(self) -> bool:
         return False
-
-    def supports_stop_words(self) -> bool:
-        return False
-
-    def get_context_window_size(self):
-        pass
-
-    def set_callbacks(self, callbacks: List[Any]) -> None:
-        pass
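
After this cleanup AISuiteLLM falls back to the BaseLLM defaults, which is a small behavior change worth noting: supports_stop_words() now reports True (the removed override returned False) and get_context_window_size() returns 4096 instead of None. A quick check (the import path and "provider:model" constructor argument are assumptions based on aisuite conventions):

from crewai.llms.third_party.ai_suite_llm import AISuiteLLM  # path assumed

llm = AISuiteLLM(model="openai:gpt-4o")
print(llm.supports_function_calling())  # False, still overridden
print(llm.supports_stop_words())        # True, inherited BaseLLM default
print(llm.get_context_window_size())    # 4096, inherited BaseLLM default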