mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Brandon/eng 266 conversation crew v1 (#1843)
* worked on foundation for new conversational crews. Now going to work on chatting. * core loop should be working and ready for testing. * high level chat working * its alive!! * Added in Joaos feedback to steer crew chats back towards the purpose of the crew * properly return tool call result * accessing crew directly instead of through uv commands * everything is working for conversation now * Fix linting * fix llm_utils.py and other type errors * fix more type errors * fixing type error * More fixing of types * fix failing tests * Fix more failing tests * adding tests. cleaing up pr. * improve * drop old functions * improve type hintings
This commit is contained in:
committed by
GitHub
parent
a2f839fada
commit
8f57753656
@@ -1,21 +1,27 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
import warnings
|
||||
from contextlib import contextmanager
|
||||
from importlib import resources
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
from typing import Any, Dict, List, Optional, Union, cast
|
||||
|
||||
from dotenv import load_dotenv
|
||||
|
||||
with warnings.catch_warnings():
|
||||
warnings.simplefilter("ignore", UserWarning)
|
||||
import litellm
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import Choices, get_supported_openai_params
|
||||
from litellm.types.utils import ModelResponse
|
||||
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
)
|
||||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
class FilteredStream:
|
||||
def __init__(self, original_stream):
|
||||
@@ -24,6 +30,7 @@ class FilteredStream:
|
||||
|
||||
def write(self, s) -> int:
|
||||
with self._lock:
|
||||
# Filter out extraneous messages from LiteLLM
|
||||
if (
|
||||
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
|
||||
in s
|
||||
@@ -79,18 +86,18 @@ CONTEXT_WINDOW_USAGE_RATIO = 0.75
|
||||
def suppress_warnings():
|
||||
with warnings.catch_warnings():
|
||||
warnings.filterwarnings("ignore")
|
||||
warnings.filterwarnings("ignore", message="open_text is deprecated*", category=DeprecationWarning)
|
||||
warnings.filterwarnings(
|
||||
"ignore", message="open_text is deprecated*", category=DeprecationWarning
|
||||
)
|
||||
|
||||
# Redirect stdout and stderr
|
||||
old_stdout = sys.stdout
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = FilteredStream(old_stdout)
|
||||
sys.stderr = FilteredStream(old_stderr)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore stdout and stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
@@ -111,13 +118,12 @@ class LLM:
|
||||
logit_bias: Optional[Dict[int, float]] = None,
|
||||
response_format: Optional[Dict[str, Any]] = None,
|
||||
seed: Optional[int] = None,
|
||||
logprobs: Optional[bool] = None,
|
||||
logprobs: Optional[int] = None,
|
||||
top_logprobs: Optional[int] = None,
|
||||
base_url: Optional[str] = None,
|
||||
api_version: Optional[str] = None,
|
||||
api_key: Optional[str] = None,
|
||||
callbacks: List[Any] = [],
|
||||
**kwargs,
|
||||
):
|
||||
self.model = model
|
||||
self.timeout = timeout
|
||||
@@ -139,19 +145,40 @@ class LLM:
|
||||
self.api_key = api_key
|
||||
self.callbacks = callbacks
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
|
||||
litellm.drop_params = True
|
||||
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
def call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
High-level call method that:
|
||||
1) Calls litellm.completion
|
||||
2) Checks for function/tool calls
|
||||
3) If a tool call is found:
|
||||
a) executes the function
|
||||
b) returns the result
|
||||
4) If no tool call, returns the text response
|
||||
|
||||
:param messages: The conversation messages
|
||||
:param tools: Optional list of function schemas for function calling
|
||||
:param callbacks: Optional list of callbacks
|
||||
:param available_functions: A dictionary mapping function_name -> actual Python function
|
||||
:return: Final text response from the LLM or the tool result
|
||||
"""
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
try:
|
||||
# --- 1) Make the completion call
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
@@ -172,21 +199,58 @@ class LLM:
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
**self.kwargs,
|
||||
"tools": tools, # pass the tool schema
|
||||
}
|
||||
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
# --- 2) If no tool calls, return the text response
|
||||
if not tool_calls or not available_functions:
|
||||
return text_response
|
||||
|
||||
# --- 3) Handle the tool call
|
||||
tool_call = tool_calls[0]
|
||||
function_name = tool_call.function.name
|
||||
|
||||
if function_name in available_functions:
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.warning(f"Failed to parse function arguments: {e}")
|
||||
return text_response
|
||||
|
||||
fn = available_functions[function_name]
|
||||
try:
|
||||
# Call the actual tool function
|
||||
result = fn(**function_args)
|
||||
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error executing function '{function_name}': {e}"
|
||||
)
|
||||
return text_response
|
||||
|
||||
else:
|
||||
logging.warning(
|
||||
f"Tool call requested unknown function '{function_name}'"
|
||||
)
|
||||
return text_response
|
||||
|
||||
except Exception as e:
|
||||
if not LLMContextLengthExceededException(
|
||||
str(e)
|
||||
)._is_context_limit_error(str(e)):
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
|
||||
raise # Re-raise the exception after logging
|
||||
raise
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
@@ -205,7 +269,10 @@ class LLM:
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
# Only using 75% of the context window size to avoid cutting the message in the middle
|
||||
"""
|
||||
Returns the context window size, using 75% of the maximum to avoid
|
||||
cutting off messages mid-thread.
|
||||
"""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
@@ -218,6 +285,10 @@ class LLM:
|
||||
return self.context_window_size
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
"""
|
||||
Attempt to keep a single set of callbacks in litellm by removing old
|
||||
duplicates and adding new ones.
|
||||
"""
|
||||
with suppress_warnings():
|
||||
callback_types = [type(callback) for callback in callbacks]
|
||||
for callback in litellm.success_callback[:]:
|
||||
@@ -254,15 +325,15 @@ class LLM:
|
||||
success_callbacks = []
|
||||
if success_callbacks_str:
|
||||
success_callbacks = [
|
||||
callback.strip() for callback in success_callbacks_str.split(",")
|
||||
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
|
||||
failure_callbacks = []
|
||||
if failure_callbacks_str:
|
||||
failure_callbacks = [
|
||||
callback.strip() for callback in failure_callbacks_str.split(",")
|
||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
litellm.success_callback = success_callbacks
|
||||
litellm.failure_callback = failure_callbacks
|
||||
|
||||
Reference in New Issue
Block a user