mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-13 01:58:30 +00:00
Refactor Crew class and LLM hierarchy for improved type handling and code clarity
- Update Crew class methods to enhance readability with consistent formatting and type hints. - Change LLM class to inherit from BaseLLM for better structure. - Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor. - Adjust BaseLLM to provide default implementations for stop words and context window size methods. - Clean up AISuiteLLM by removing unused methods related to stop words and context window size.
This commit is contained in:
@@ -13,7 +13,7 @@ from crewai.agents.parser import (
|
||||
OutputParserException,
|
||||
)
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import BaseLLM
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException
|
||||
from crewai.utilities import I18N, Printer
|
||||
@@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
callbacks: List[Any] = [],
|
||||
):
|
||||
self._i18n: I18N = I18N()
|
||||
self.llm: Union[LLM, BaseLLM] = llm
|
||||
self.llm: BaseLLM = llm
|
||||
self.task = task
|
||||
self.agent = agent
|
||||
self.crew = crew
|
||||
|
||||
@@ -805,7 +805,11 @@ class Crew(BaseModel):
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task))
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
cast(Union[List[Tool], List[BaseTool]], tools_for_task),
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
|
||||
@@ -877,7 +881,9 @@ class Crew(BaseModel):
|
||||
self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False):
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
):
|
||||
if self.process == Process.hierarchical:
|
||||
if self.manager_agent:
|
||||
tools = self._update_manager_tools(task, tools)
|
||||
@@ -890,10 +896,16 @@ class Crew(BaseModel):
|
||||
tools = self._add_delegation_tools(task, tools)
|
||||
|
||||
# Add code execution tools if agent allows code execution
|
||||
if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False):
|
||||
if hasattr(agent, "allow_code_execution") and getattr(
|
||||
agent, "allow_code_execution", False
|
||||
):
|
||||
tools = self._add_code_execution_tools(agent, tools)
|
||||
|
||||
if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False):
|
||||
if (
|
||||
agent
|
||||
and hasattr(agent, "multimodal")
|
||||
and getattr(agent, "multimodal", False)
|
||||
):
|
||||
tools = self._add_multimodal_tools(agent, tools)
|
||||
|
||||
# Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async
|
||||
@@ -905,7 +917,9 @@ class Crew(BaseModel):
|
||||
return task.agent
|
||||
|
||||
def _merge_tools(
|
||||
self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]]
|
||||
self,
|
||||
existing_tools: Union[List[Tool], List[BaseTool]],
|
||||
new_tools: Union[List[Tool], List[BaseTool]],
|
||||
) -> List[BaseTool]:
|
||||
"""Merge new tools into existing tools list, avoiding duplicates by tool name."""
|
||||
if not new_tools:
|
||||
@@ -923,7 +937,10 @@ class Crew(BaseModel):
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _inject_delegation_tools(
|
||||
self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent]
|
||||
self,
|
||||
tools: Union[List[Tool], List[BaseTool]],
|
||||
task_agent: BaseAgent,
|
||||
agents: List[BaseAgent],
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
@@ -931,21 +948,27 @@ class Crew(BaseModel):
|
||||
return self._merge_tools(tools, cast(List[BaseTool], delegation_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _add_multimodal_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_multimodal_tools"):
|
||||
multimodal_tools = agent.get_multimodal_tools()
|
||||
# Cast multimodal_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _add_code_execution_tools(
|
||||
self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(List[BaseTool], code_tools))
|
||||
return cast(List[BaseTool], tools)
|
||||
|
||||
def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _add_delegation_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
agents_for_delegation = [agent for agent in self.agents if agent != task.agent]
|
||||
if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent:
|
||||
if not tools:
|
||||
@@ -961,7 +984,9 @@ class Crew(BaseModel):
|
||||
task_name=task.name, task=task.description, agent=role, status="started"
|
||||
)
|
||||
|
||||
def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]:
|
||||
def _update_manager_tools(
|
||||
self, task: Task, tools: Union[List[Tool], List[BaseTool]]
|
||||
) -> List[BaseTool]:
|
||||
if self.manager_agent:
|
||||
if task.agent:
|
||||
tools = self._inject_delegation_tools(tools, task.agent, [task.agent])
|
||||
@@ -1210,7 +1235,7 @@ class Crew(BaseModel):
|
||||
def test(
|
||||
self,
|
||||
n_iterations: int,
|
||||
eval_llm: Union[str, InstanceOf[LLM]],
|
||||
eval_llm: Union[str, InstanceOf[BaseLLM]],
|
||||
inputs: Optional[Dict[str, Any]] = None,
|
||||
) -> None:
|
||||
"""Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures."""
|
||||
@@ -1219,11 +1244,6 @@ class Crew(BaseModel):
|
||||
llm_instance = create_llm(eval_llm)
|
||||
if not llm_instance:
|
||||
raise ValueError("Failed to create LLM instance.")
|
||||
|
||||
# Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator
|
||||
from crewai.llm import LLM
|
||||
if not isinstance(llm_instance, LLM):
|
||||
raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.")
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -165,7 +165,7 @@ class StreamingChoices(TypedDict):
|
||||
finish_reason: Optional[str]
|
||||
|
||||
|
||||
class LLM:
|
||||
class LLM(BaseLLM):
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Callable, Dict, List, Optional, Union
|
||||
|
||||
|
||||
class BaseLLM(ABC):
|
||||
@@ -21,13 +21,12 @@ class BaseLLM(ABC):
|
||||
|
||||
model: str
|
||||
temperature: Optional[float] = None
|
||||
stop: Optional[Union[str, List[str]]] = None
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str,
|
||||
temperature: Optional[float] = None,
|
||||
stop: Optional[Union[str, List[str]]] = None,
|
||||
):
|
||||
"""Initialize the BaseLLM with default attributes.
|
||||
|
||||
@@ -39,7 +38,7 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
self.model = model
|
||||
self.temperature = temperature
|
||||
self.stop = stop
|
||||
self.stop = []
|
||||
|
||||
@abstractmethod
|
||||
def call(
|
||||
@@ -74,42 +73,53 @@ class BaseLLM(ABC):
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the LLM supports function calling.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
function calling (tools), and False otherwise. If this method returns
|
||||
True, the LLM should be able to handle the 'tools' parameter in the
|
||||
call() method.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports function calling, False otherwise.
|
||||
"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def supports_stop_words(self) -> bool:
|
||||
"""Check if the LLM supports stop words.
|
||||
|
||||
This method should return True if the LLM implementation supports
|
||||
stop words, and False otherwise. If this method returns True, the
|
||||
LLM should respect the 'stop' attribute when generating responses.
|
||||
|
||||
Returns:
|
||||
True if the LLM supports stop words, False otherwise.
|
||||
bool: True if the LLM supports stop words, False otherwise.
|
||||
"""
|
||||
pass
|
||||
return True # Default implementation assumes support for stop words
|
||||
|
||||
@abstractmethod
|
||||
def get_context_window_size(self) -> int:
|
||||
"""Get the context window size of the LLM.
|
||||
|
||||
This method should return the maximum number of tokens that the LLM
|
||||
can process in a single request. This is used by CrewAI to ensure
|
||||
that messages don't exceed the LLM's context window.
|
||||
"""Get the context window size for the LLM.
|
||||
|
||||
Returns:
|
||||
The context window size as an integer.
|
||||
int: The number of tokens/characters the model can handle.
|
||||
"""
|
||||
pass
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return 4096
|
||||
|
||||
def stream(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
stream_callback: Optional[Callable[[str], None]] = None,
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""Stream responses from the LLM with optional callbacks for each chunk.
|
||||
|
||||
Args:
|
||||
messages: Input messages for the LLM.
|
||||
Can be a string or list of message dictionaries.
|
||||
stream_callback: Optional callback function that receives each
|
||||
text chunk as it arrives.
|
||||
tools: Optional list of tool schemas for function calling.
|
||||
callbacks: Optional list of callback functions.
|
||||
available_functions: Optional dict mapping function names to callables.
|
||||
|
||||
Returns:
|
||||
The complete response as a string (after streaming is complete).
|
||||
|
||||
Raises:
|
||||
ValueError: If the messages format is invalid.
|
||||
TimeoutError: If the LLM request times out.
|
||||
RuntimeError: If the LLM request fails for other reasons.
|
||||
"""
|
||||
# Default implementation that doesn't actually stream but calls the callback
|
||||
# Subclasses should override this with proper streaming implementations
|
||||
response = self.call(messages, tools, callbacks, available_functions)
|
||||
if stream_callback:
|
||||
stream_callback(response)
|
||||
return response
|
||||
|
||||
9
src/crewai/llms/third_party/ai_suite.py
vendored
9
src/crewai/llms/third_party/ai_suite.py
vendored
@@ -36,12 +36,3 @@ class AISuiteLLM(BaseLLM):
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
return False
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
return False
|
||||
|
||||
def get_context_window_size(self):
|
||||
pass
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]) -> None:
|
||||
pass
|
||||
|
||||
Reference in New Issue
Block a user