diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 5823ef7f9..3a4d083d4 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -1,15 +1,12 @@ -import os import shutil import subprocess from typing import Any, Dict, List, Literal, Optional, Union -from litellm import AuthenticationError as LiteLLMAuthenticationError from pydantic import Field, InstanceOf, PrivateAttr, model_validator from crewai.agents import CacheHandler from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.crew_agent_executor import CrewAgentExecutor -from crewai.cli.constants import ENV_VARS, LITELLM_PARAMS from crewai.knowledge.knowledge import Knowledge from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context @@ -262,8 +259,8 @@ class Agent(BaseAgent): } )["output"] except Exception as e: - if isinstance(e, LiteLLMAuthenticationError): - # Do not retry on authentication errors + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors raise e self._times_executed += 1 if self._times_executed > self.max_retry_limit: diff --git a/src/crewai/agents/crew_agent_executor.py b/src/crewai/agents/crew_agent_executor.py index d7bf97795..b9797193c 100644 --- a/src/crewai/agents/crew_agent_executor.py +++ b/src/crewai/agents/crew_agent_executor.py @@ -3,8 +3,6 @@ import re from dataclasses import dataclass from typing import Any, Callable, Dict, List, Optional, Union -from litellm.exceptions import AuthenticationError as LiteLLMAuthenticationError - from crewai.agents.agent_builder.base_agent import BaseAgent from crewai.agents.agent_builder.base_agent_executor_mixin import CrewAgentExecutorMixin from crewai.agents.parser import ( @@ -103,7 +101,12 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): try: formatted_answer = self._invoke_loop() except Exception as e: - raise e + if e.__class__.__module__.startswith("litellm"): + # Do not retry 
on litellm errors + raise e + else: + self._handle_unknown_error(e) + raise e if self.ask_for_human_input: formatted_answer = self._handle_human_feedback(formatted_answer) @@ -146,6 +149,9 @@ class CrewAgentExecutor(CrewAgentExecutorMixin): formatted_answer = self._handle_output_parser_exception(e) except Exception as e: + if e.__class__.__module__.startswith("litellm"): + # Do not retry on litellm errors + raise e if self._is_context_length_exceeded(e): self._handle_context_length() continue diff --git a/src/crewai/cli/cli.py b/src/crewai/cli/cli.py index 334759a6d..761cc52ad 100644 --- a/src/crewai/cli/cli.py +++ b/src/crewai/cli/cli.py @@ -350,7 +350,10 @@ def chat(): Start a conversation with the Crew, collecting user-supplied inputs, and using the Chat LLM to generate responses. """ - click.echo("Starting a conversation with the Crew") + click.secho( + "\nStarting a conversation with the Crew\n" "Type 'exit' or Ctrl+C to quit.\n", + ) + run_chat() diff --git a/src/crewai/cli/crew_chat.py b/src/crewai/cli/crew_chat.py index f1695f0a4..cd0da2bb8 100644 --- a/src/crewai/cli/crew_chat.py +++ b/src/crewai/cli/crew_chat.py @@ -1,6 +1,9 @@ import json +import platform import re import sys +import threading +import time from pathlib import Path from typing import Any, Dict, List, Optional, Set, Tuple @@ -18,27 +21,29 @@ from crewai.utilities.llm_utils import create_llm MIN_REQUIRED_VERSION = "0.98.0" -def check_conversational_crews_version(crewai_version: str, pyproject_data: dict) -> bool: +def check_conversational_crews_version( + crewai_version: str, pyproject_data: dict +) -> bool: """ Check if the installed crewAI version supports conversational crews. - + Args: - crewai_version: The current version of crewAI - pyproject_data: Dictionary containing pyproject.toml data - + crewai_version: The current version of crewAI. + pyproject_data: Dictionary containing pyproject.toml data. 
+ Returns: - bool: True if version check passes, False otherwise + bool: True if version check passes, False otherwise. """ try: if version.parse(crewai_version) < version.parse(MIN_REQUIRED_VERSION): click.secho( "You are using an older version of crewAI that doesn't support conversational crews. " "Run 'uv upgrade crewai' to get the latest version.", - fg="red" + fg="red", ) return False except version.InvalidVersion: - click.secho("Invalid crewAI version format detected", fg="red") + click.secho("Invalid crewAI version format detected.", fg="red") return False return True @@ -54,20 +59,42 @@ def run_chat(): if not check_conversational_crews_version(crewai_version, pyproject_data): return + crew, crew_name = load_crew_and_name() chat_llm = initialize_chat_llm(crew) if not chat_llm: return - crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) - crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) - system_message = build_system_message(crew_chat_inputs) - - # Call the LLM to generate the introductory message - introductory_message = chat_llm.call( - messages=[{"role": "system", "content": system_message}] + # Indicate that the crew is being analyzed + click.secho( + "\nAnalyzing crew and required inputs - this may take 3 to 30 seconds " + "depending on the complexity of your crew.", + fg="white", ) - click.secho(f"\nAssistant: {introductory_message}\n", fg="green") + + # Start loading indicator + loading_complete = threading.Event() + loading_thread = threading.Thread(target=show_loading, args=(loading_complete,)) + loading_thread.start() + + try: + crew_chat_inputs = generate_crew_chat_inputs(crew, crew_name, chat_llm) + crew_tool_schema = generate_crew_tool_schema(crew_chat_inputs) + system_message = build_system_message(crew_chat_inputs) + + # Call the LLM to generate the introductory message + introductory_message = chat_llm.call( + messages=[{"role": "system", "content": system_message}] + ) + finally: + # Stop loading indicator 
+ loading_complete.set() + loading_thread.join() + + # Indicate that the analysis is complete + click.secho("\nFinished analyzing crew.\n", fg="white") + + click.secho(f"Assistant: {introductory_message}\n", fg="green") messages = [ {"role": "system", "content": system_message}, @@ -78,15 +105,17 @@ def run_chat(): crew_chat_inputs.crew_name: create_tool_function(crew, messages), } - click.secho( - "\nEntering an interactive chat loop with function-calling.\n" - "Type 'exit' or Ctrl+C to quit.\n", - fg="cyan", - ) - chat_loop(chat_llm, messages, crew_tool_schema, available_functions) +def show_loading(event: threading.Event): + """Display animated loading dots while processing.""" + while not event.is_set(): + print(".", end="", flush=True) + time.sleep(1) + print() + + def initialize_chat_llm(crew: Crew) -> Optional[LLM]: """Initializes the chat LLM and handles exceptions.""" try: @@ -120,7 +149,7 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str: "Please keep your responses concise and friendly. " "If a user asks a question outside the crew's scope, provide a brief answer and remind them of the crew's purpose. " "After calling the tool, be prepared to take user feedback and make adjustments as needed. " - "If you are ever unsure about a user's request or need clarification, ask the user for more information." + "If you are ever unsure about a user's request or need clarification, ask the user for more information. " "Before doing anything else, introduce yourself with a friendly message like: 'Hey! I'm here to help you with [crew's purpose]. Could you please provide me with [inputs] so we can get started?' " "For example: 'Hey! I'm here to help you with uncovering and reporting cutting-edge developments through thorough research and detailed analysis. Could you please provide me with a topic you're interested in? 
This will help us generate a comprehensive research report and detailed analysis.'" f"\nCrew Name: {crew_chat_inputs.crew_name}" @@ -137,25 +166,33 @@ def create_tool_function(crew: Crew, messages: List[Dict[str, str]]) -> Any: return run_crew_tool_with_messages +def flush_input(): + """Flush any pending input from the user.""" + if platform.system() == "Windows": + # Windows platform + import msvcrt + + while msvcrt.kbhit(): + msvcrt.getch() + else: + # Unix-like platforms (Linux, macOS) + import termios + + termios.tcflush(sys.stdin, termios.TCIFLUSH) + + def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): """Main chat loop for interacting with the user.""" while True: try: - user_input = click.prompt("You", type=str) - if user_input.strip().lower() in ["exit", "quit"]: - click.echo("Exiting chat. Goodbye!") - break + # Flush any pending input before accepting new input + flush_input() - messages.append({"role": "user", "content": user_input}) - final_response = chat_llm.call( - messages=messages, - tools=[crew_tool_schema], - available_functions=available_functions, + user_input = get_user_input() + handle_user_input( + user_input, chat_llm, messages, crew_tool_schema, available_functions ) - messages.append({"role": "assistant", "content": final_response}) - click.secho(f"\nAssistant: {final_response}\n", fg="green") - except KeyboardInterrupt: click.echo("\nExiting chat. Goodbye!") break @@ -164,6 +201,55 @@ def chat_loop(chat_llm, messages, crew_tool_schema, available_functions): break +def get_user_input() -> str: + """Collect multi-line user input with exit handling.""" + click.secho( + "\nYou (type your message below. 
Press 'Enter' twice when you're done):", + fg="blue", + ) + user_input_lines = [] + while True: + line = input() + if line.strip().lower() == "exit": + return "exit" + if line == "": + break + user_input_lines.append(line) + return "\n".join(user_input_lines) + + +def handle_user_input( + user_input: str, + chat_llm: LLM, + messages: List[Dict[str, str]], + crew_tool_schema: Dict[str, Any], + available_functions: Dict[str, Any], +) -> None: + if user_input.strip().lower() == "exit": + click.echo("Exiting chat. Goodbye!") + return + + if not user_input.strip(): + click.echo("Empty message. Please provide input or type 'exit' to quit.") + return + + messages.append({"role": "user", "content": user_input}) + + # Indicate that assistant is processing + click.echo() + click.secho("Assistant is processing your input. Please wait...", fg="green") + + # Process assistant's response + final_response = chat_llm.call( + messages=messages, + tools=[crew_tool_schema], + available_functions=available_functions, + ) + + messages.append({"role": "assistant", "content": final_response}) + click.secho(f"\nAssistant: {final_response}\n", fg="green") + + def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict: """ Dynamically build a LiteLLM 'function' schema for the given crew. 
@@ -358,10 +444,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> ): # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description + lambda m: m.group(1), task.description or "" ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output + lambda m: m.group(1), task.expected_output or "" ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") @@ -372,10 +458,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> or f"{{{input_name}}}" in agent.backstory ): # Replace placeholders with input names - agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) - agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal) + agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") + agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") agent_backstory = placeholder_pattern.sub( - lambda m: m.group(1), agent.backstory + lambda m: m.group(1), agent.backstory or "" ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") @@ -416,18 +502,20 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str: for task in crew.tasks: # Replace placeholders with input names task_description = placeholder_pattern.sub( - lambda m: m.group(1), task.description + lambda m: m.group(1), task.description or "" ) expected_output = placeholder_pattern.sub( - lambda m: m.group(1), task.expected_output + lambda m: m.group(1), task.expected_output or "" ) context_texts.append(f"Task Description: {task_description}") context_texts.append(f"Expected Output: {expected_output}") for agent in crew.agents: # Replace placeholders with input names - agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role) - agent_goal = 
placeholder_pattern.sub(lambda m: m.group(1), agent.goal) - agent_backstory = placeholder_pattern.sub(lambda m: m.group(1), agent.backstory) + agent_role = placeholder_pattern.sub(lambda m: m.group(1), agent.role or "") + agent_goal = placeholder_pattern.sub(lambda m: m.group(1), agent.goal or "") + agent_backstory = placeholder_pattern.sub( + lambda m: m.group(1), agent.backstory or "" + ) context_texts.append(f"Agent Role: {agent_role}") context_texts.append(f"Agent Goal: {agent_goal}") context_texts.append(f"Agent Backstory: {agent_backstory}") diff --git a/tests/agent_test.py b/tests/agent_test.py index 3ed51ebde..fda47daaf 100644 --- a/tests/agent_test.py +++ b/tests/agent_test.py @@ -1623,7 +1623,7 @@ def test_litellm_auth_error_handling(): agent=agent, ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), @@ -1639,7 +1639,7 @@ def test_litellm_auth_error_handling(): def test_crew_agent_executor_litellm_auth_error(): """Test that CrewAgentExecutor handles LiteLLM authentication errors by raising them.""" - from litellm import AuthenticationError as LiteLLMAuthenticationError + from litellm.exceptions import AuthenticationError from crewai.agents.tools_handler import ToolsHandler from crewai.utilities import Printer @@ -1672,13 +1672,13 @@ def test_crew_agent_executor_litellm_auth_error(): tools_handler=ToolsHandler(), ) - # Mock the LLM call to raise LiteLLMAuthenticationError + # Mock the LLM call to raise AuthenticationError with ( patch.object(LLM, "call") as mock_llm_call, patch.object(Printer, "print") as mock_printer, - pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"), + pytest.raises(AuthenticationError) as exc_info, ): - mock_llm_call.side_effect = LiteLLMAuthenticationError( + mock_llm_call.side_effect = AuthenticationError( message="Invalid API key", 
llm_provider="openai", model="gpt-4" ) executor.invoke( @@ -1689,14 +1689,53 @@ def test_crew_agent_executor_litellm_auth_error(): } ) - # Verify error handling + # Verify error handling messages + error_message = f"Error during LLM call: {str(mock_llm_call.side_effect)}" mock_printer.assert_any_call( - content="An unknown error occurred. Please check the details below.", - color="red", - ) - mock_printer.assert_any_call( - content="Error details: litellm.AuthenticationError: Invalid API key", + content=error_message, color="red", ) + # Verify the call was only made once (no retries) mock_llm_call.assert_called_once() + + # Assert that the exception was raised and has the expected attributes + assert exc_info.type is AuthenticationError + assert "Invalid API key".lower() in exc_info.value.message.lower() + assert exc_info.value.llm_provider == "openai" + assert exc_info.value.model == "gpt-4" + + +def test_litellm_anthropic_error_handling(): + """Test that AnthropicError from LiteLLM is handled correctly and not retried.""" + from litellm.llms.anthropic.common_utils import AnthropicError + + # Create an agent with a mocked LLM that uses an Anthropic model + agent = Agent( + role="test role", + goal="test goal", + backstory="test backstory", + llm=LLM(model="claude-3.5-sonnet-20240620"), + max_retry_limit=0, + ) + + # Create a task + task = Task( + description="Test task", + expected_output="Test output", + agent=agent, + ) + + # Mock the LLM call to raise AnthropicError + with ( + patch.object(LLM, "call") as mock_llm_call, + pytest.raises(AnthropicError, match="Test Anthropic error"), + ): + mock_llm_call.side_effect = AnthropicError( + status_code=500, + message="Test Anthropic error", + ) + agent.execute_task(task) + + # Verify the LLM call was only made once (no retries) + mock_llm_call.assert_called_once()