From e659c352df0404977a6836b6bf38115b9aaf75e0 Mon Sep 17 00:00:00 2001 From: Lorenze Jay Date: Mon, 24 Mar 2025 13:56:23 -0700 Subject: [PATCH] Refactor Crew class and LLM hierarchy for improved type handling and code clarity - Update Crew class methods to enhance readability with consistent formatting and type hints. - Change LLM class to inherit from BaseLLM for better structure. - Remove unnecessary type checks and streamline tool handling in CrewAgentExecutor. - Adjust BaseLLM to provide default implementations for stop words and context window size methods. - Clean up AISuiteLLM by removing unused methods related to stop words and context window size. --- src/crewai/agents/crew_agent_executor.py | 4 +- src/crewai/crew.py | 52 +++++++++++----- src/crewai/llm.py | 2 +- src/crewai/llms/base_llm.py | 76 ++++++++++++++---------- src/crewai/llms/third_party/ai_suite.py | 9 --- 5 files changed, 82 insertions(+), 61 deletions(-) diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index 0fe61888e..bb17cd095 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -13,7 +13,7 @@ from crewai.agents.parser import ( OutputParserException, ) from crewai.agents.tools_handler import ToolsHandler -from crewai.llm import LLM, BaseLLM +from crewai.llm import BaseLLM from crewai.tools.base_tool import BaseTool from crewai.tools.tool_usage import ToolUsage, ToolUsageErrorException from crewai.utilities import I18N, Printer @@ -61,7 +61,7 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): callbacks: List[Any] = [], ): self._i18n: I18N = I18N() - self.llm: Union[LLM, BaseLLM] = llm + self.llm: BaseLLM = llm self.task = task self.agent = agent self.crew = crew diff --git a/src/crewai/crew.py b/src/crewai/crew.py index c992a80c8..c3310b961 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -805,7 +805,11 @@ class Crew(BaseModel): # Determine which tools to use - task tools take precedence over agent tools tools_for_task = task.tools or agent_to_use.tools or [] # Prepare tools and ensure they're compatible with task execution - tools_for_task = self._prepare_tools(agent_to_use, task, cast(Union[List[Tool], List[BaseTool]], tools_for_task)) + tools_for_task = self._prepare_tools( + agent_to_use, + task, + cast(Union[List[Tool], List[BaseTool]], tools_for_task), + ) self._log_task_start(task, agent_to_use.role) @@ -877,7 +881,9 @@ class Crew(BaseModel): self, agent: BaseAgent, task: Task, tools: Union[List[Tool], List[BaseTool]] ) -> List[BaseTool]: # Add delegation tools if agent allows delegation - if hasattr(agent, "allow_delegation") and getattr(agent, "allow_delegation", False): + if hasattr(agent, "allow_delegation") and getattr( + agent, "allow_delegation", False + ): if self.process == Process.hierarchical: if self.manager_agent: tools = self._update_manager_tools(task, tools) @@ -890,10 +896,16 @@ class Crew(BaseModel): tools = self._add_delegation_tools(task, tools) # Add code execution tools if agent allows code execution - if hasattr(agent, "allow_code_execution") and getattr(agent, "allow_code_execution", False): + if hasattr(agent, "allow_code_execution") and getattr( + agent, "allow_code_execution", False + ): tools = self._add_code_execution_tools(agent, tools) - if agent and hasattr(agent, "multimodal") and getattr(agent, "multimodal", False): + if ( + agent + and hasattr(agent, "multimodal") + and getattr(agent, "multimodal", False) + ): tools = self._add_multimodal_tools(agent, tools) # Return a List[BaseTool] which is compatible with both Task.execute_sync and Task.execute_async @@ -905,7 +917,9 @@ class Crew(BaseModel): return task.agent def _merge_tools( - self, existing_tools: Union[List[Tool], List[BaseTool]], new_tools: Union[List[Tool], List[BaseTool]] + self, + existing_tools: Union[List[Tool], List[BaseTool]], + new_tools: Union[List[Tool], List[BaseTool]], ) -> List[BaseTool]: """Merge new tools into existing tools list, avoiding duplicates by tool name.""" if not new_tools: @@ -923,7 +937,10 @@ class Crew(BaseModel): return cast(List[BaseTool], tools) def _inject_delegation_tools( - self, tools: Union[List[Tool], List[BaseTool]], task_agent: BaseAgent, agents: List[BaseAgent] + self, + tools: Union[List[Tool], List[BaseTool]], + task_agent: BaseAgent, + agents: List[BaseAgent], ) -> List[BaseTool]: if hasattr(task_agent, "get_delegation_tools"): delegation_tools = task_agent.get_delegation_tools(agents) @@ -931,21 +948,27 @@ class Crew(BaseModel): return self._merge_tools(tools, cast(List[BaseTool], delegation_tools)) return cast(List[BaseTool], tools) - def _add_multimodal_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + def _add_multimodal_tools( + self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: if hasattr(agent, "get_multimodal_tools"): multimodal_tools = agent.get_multimodal_tools() # Cast multimodal_tools to the expected type for _merge_tools return self._merge_tools(tools, cast(List[BaseTool], multimodal_tools)) return cast(List[BaseTool], tools) - def _add_code_execution_tools(self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + def _add_code_execution_tools( + self, agent: BaseAgent, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: if hasattr(agent, "get_code_execution_tools"): code_tools = agent.get_code_execution_tools() # Cast code_tools to the expected type for _merge_tools return self._merge_tools(tools, cast(List[BaseTool], code_tools)) return cast(List[BaseTool], tools) - def _add_delegation_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + def _add_delegation_tools( + self, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: agents_for_delegation = [agent for agent in self.agents if agent != task.agent] if len(self.agents) > 1 and len(agents_for_delegation) > 0 and task.agent: if not tools: @@ -961,7 +984,9 @@ class Crew(BaseModel): task_name=task.name, task=task.description, agent=role, status="started" ) - def _update_manager_tools(self, task: Task, tools: Union[List[Tool], List[BaseTool]]) -> List[BaseTool]: + def _update_manager_tools( + self, task: Task, tools: Union[List[Tool], List[BaseTool]] + ) -> List[BaseTool]: if self.manager_agent: if task.agent: tools = self._inject_delegation_tools(tools, task.agent, [task.agent]) @@ -1210,7 +1235,7 @@ class Crew(BaseModel): def test( self, n_iterations: int, - eval_llm: Union[str, InstanceOf[LLM]], + eval_llm: Union[str, InstanceOf[BaseLLM]], inputs: Optional[Dict[str, Any]] = None, ) -> None: """Test and evaluate the Crew with the given inputs for n iterations concurrently using concurrent.futures.""" @@ -1219,11 +1244,6 @@ class Crew(BaseModel): llm_instance = create_llm(eval_llm) if not llm_instance: raise ValueError("Failed to create LLM instance.") - - # Ensure we have an LLM instance (not just BaseLLM) for CrewEvaluator - from crewai.llm import LLM - if not isinstance(llm_instance, LLM): - raise TypeError("CrewEvaluator requires an LLM instance, not a BaseLLM instance.") crewai_event_bus.emit( self, diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 5d0b377e6..68000b0f7 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -165,7 +165,7 @@ class StreamingChoices(TypedDict): finish_reason: Optional[str] -class LLM: +class LLM(BaseLLM): def __init__( self, model: str, diff --git a/src/crewai/llms/base_llm.py b/src/crewai/llms/base_llm.py index c8eef4fc7..d77579400 100644 --- a/src/crewai/llms/base_llm.py +++ b/src/crewai/llms/base_llm.py @@ -1,5 +1,5 @@ from abc import ABC, abstractmethod -from typing import Any, Dict, List, Optional, Union +from typing import Any, Callable, Dict, List, Optional, Union class BaseLLM(ABC): @@ -21,13 +21,12 @@ class BaseLLM(ABC): model: str temperature: Optional[float] = None - stop: Optional[Union[str, List[str]]] = None + stop: Optional[List[str]] = None def __init__( self, model: str, temperature: Optional[float] = None, - stop: Optional[Union[str, List[str]]] = None, ): """Initialize the BaseLLM with default attributes. @@ -39,7 +38,7 @@ class BaseLLM(ABC): """ self.model = model self.temperature = temperature - self.stop = stop + self.stop = [] @abstractmethod def call( @@ -74,42 +73,53 @@ class BaseLLM(ABC): """ pass - @abstractmethod - def supports_function_calling(self) -> bool: - """Check if the LLM supports function calling. - - This method should return True if the LLM implementation supports - function calling (tools), and False otherwise. If this method returns - True, the LLM should be able to handle the 'tools' parameter in the - call() method. - - Returns: - True if the LLM supports function calling, False otherwise. - """ - pass - - @abstractmethod def supports_stop_words(self) -> bool: """Check if the LLM supports stop words. - This method should return True if the LLM implementation supports - stop words, and False otherwise. If this method returns True, the - LLM should respect the 'stop' attribute when generating responses. - Returns: - True if the LLM supports stop words, False otherwise. + bool: True if the LLM supports stop words, False otherwise. """ - pass + return True # Default implementation assumes support for stop words - @abstractmethod def get_context_window_size(self) -> int: - """Get the context window size of the LLM. - - This method should return the maximum number of tokens that the LLM - can process in a single request. This is used by CrewAI to ensure - that messages don't exceed the LLM's context window. + """Get the context window size for the LLM. Returns: - The context window size as an integer. + int: The number of tokens/characters the model can handle. """ - pass + # Default implementation - subclasses should override with model-specific values + return 4096 + + def stream( + self, + messages: Union[str, List[Dict[str, str]]], + stream_callback: Optional[Callable[[str], None]] = None, + tools: Optional[List[dict]] = None, + callbacks: Optional[List[Any]] = None, + available_functions: Optional[Dict[str, Any]] = None, + ) -> str: + """Stream responses from the LLM with optional callbacks for each chunk. + + Args: + messages: Input messages for the LLM. + Can be a string or list of message dictionaries. + stream_callback: Optional callback function that receives each + text chunk as it arrives. + tools: Optional list of tool schemas for function calling. + callbacks: Optional list of callback functions. + available_functions: Optional dict mapping function names to callables. + + Returns: + The complete response as a string (after streaming is complete). + + Raises: + ValueError: If the messages format is invalid. + TimeoutError: If the LLM request times out. + RuntimeError: If the LLM request fails for other reasons. + """ + # Default implementation that doesn't actually stream but calls the callback + # Subclasses should override this with proper streaming implementations + response = self.call(messages, tools, callbacks, available_functions) + if stream_callback: + stream_callback(response) + return response diff --git a/src/crewai/llms/third_party/ai_suite.py b/src/crewai/llms/third_party/ai_suite.py index 42a708c89..78185a081 100644 --- a/src/crewai/llms/third_party/ai_suite.py +++ b/src/crewai/llms/third_party/ai_suite.py @@ -36,12 +36,3 @@ class AISuiteLLM(BaseLLM): def supports_function_calling(self) -> bool: return False - - def supports_stop_words(self) -> bool: - return False - - def get_context_window_size(self): - pass - - def set_callbacks(self, callbacks: List[Any]) -> None: - pass