Brandon/eng 266 conversation crew v1 (#1843)

* worked on foundation for new conversational crews. Now going to work on chatting.

* core loop should be working and ready for testing.

* high level chat working

* it's alive!!

* Added in Joao's feedback to steer crew chats back towards the purpose of the crew

* properly return tool call result

* accessing crew directly instead of through uv commands

* everything is working for conversation now

* Fix linting

* fix llm_utils.py and other type errors

* fix more type errors

* fixing type error

* More fixing of types

* fix failing tests

* Fix more failing tests

* adding tests. cleaning up PR.

* improve

* drop old functions

* improve type hints
This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-01-06 16:12:43 -05:00
committed by GitHub
parent a2f839fada
commit 8f57753656
17 changed files with 1195 additions and 1035 deletions

View File

@@ -1,21 +1,27 @@
import json
import logging
import os
import sys
import threading
import warnings
from contextlib import contextmanager
from importlib import resources
from typing import Any, Dict, List, Optional, Union
from typing import Any, Dict, List, Optional, Union, cast
from dotenv import load_dotenv
with warnings.catch_warnings():
warnings.simplefilter("ignore", UserWarning)
import litellm
from litellm import get_supported_openai_params
from litellm import Choices, get_supported_openai_params
from litellm.types.utils import ModelResponse
from crewai.utilities.exceptions.context_window_exceeding_exception import (
LLMContextLengthExceededException,
)
load_dotenv()
class FilteredStream:
def __init__(self, original_stream):
@@ -24,6 +30,7 @@ class FilteredStream:
def write(self, s) -> int:
with self._lock:
# Filter out extraneous messages from LiteLLM
if (
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
in s
@@ -79,18 +86,18 @@ CONTEXT_WINDOW_USAGE_RATIO = 0.75
def suppress_warnings():
with warnings.catch_warnings():
warnings.filterwarnings("ignore")
warnings.filterwarnings("ignore", message="open_text is deprecated*", category=DeprecationWarning)
warnings.filterwarnings(
"ignore", message="open_text is deprecated*", category=DeprecationWarning
)
# Redirect stdout and stderr
old_stdout = sys.stdout
old_stderr = sys.stderr
sys.stdout = FilteredStream(old_stdout)
sys.stderr = FilteredStream(old_stderr)
try:
yield
finally:
# Restore stdout and stderr
sys.stdout = old_stdout
sys.stderr = old_stderr
@@ -111,13 +118,12 @@ class LLM:
logit_bias: Optional[Dict[int, float]] = None,
response_format: Optional[Dict[str, Any]] = None,
seed: Optional[int] = None,
logprobs: Optional[bool] = None,
logprobs: Optional[int] = None,
top_logprobs: Optional[int] = None,
base_url: Optional[str] = None,
api_version: Optional[str] = None,
api_key: Optional[str] = None,
callbacks: List[Any] = [],
**kwargs,
):
self.model = model
self.timeout = timeout
@@ -139,19 +145,40 @@ class LLM:
self.api_key = api_key
self.callbacks = callbacks
self.context_window_size = 0
self.kwargs = kwargs
litellm.drop_params = True
self.set_callbacks(callbacks)
self.set_env_callbacks()
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
def call(
self,
messages: List[Dict[str, str]],
tools: Optional[List[dict]] = None,
callbacks: Optional[List[Any]] = None,
available_functions: Optional[Dict[str, Any]] = None,
) -> str:
"""
High-level call method that:
1) Calls litellm.completion
2) Checks for function/tool calls
3) If a tool call is found:
a) executes the function
b) returns the result
4) If no tool call, returns the text response
:param messages: The conversation messages
:param tools: Optional list of function schemas for function calling
:param callbacks: Optional list of callbacks
:param available_functions: A dictionary mapping function_name -> actual Python function
:return: Final text response from the LLM or the tool result
"""
with suppress_warnings():
if callbacks and len(callbacks) > 0:
self.set_callbacks(callbacks)
try:
# --- 1) Make the completion call
params = {
"model": self.model,
"messages": messages,
@@ -172,21 +199,58 @@ class LLM:
"api_version": self.api_version,
"api_key": self.api_key,
"stream": False,
**self.kwargs,
"tools": tools, # pass the tool schema
}
# Remove None values to avoid passing unnecessary parameters
params = {k: v for k, v in params.items() if v is not None}
response = litellm.completion(**params)
return response["choices"][0]["message"]["content"]
response_message = cast(Choices, cast(ModelResponse, response).choices)[
0
].message
text_response = response_message.content or ""
tool_calls = getattr(response_message, "tool_calls", [])
# --- 2) If no tool calls, return the text response
if not tool_calls or not available_functions:
return text_response
# --- 3) Handle the tool call
tool_call = tool_calls[0]
function_name = tool_call.function.name
if function_name in available_functions:
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.warning(f"Failed to parse function arguments: {e}")
return text_response
fn = available_functions[function_name]
try:
# Call the actual tool function
result = fn(**function_args)
return result
except Exception as e:
logging.error(
f"Error executing function '{function_name}': {e}"
)
return text_response
else:
logging.warning(
f"Tool call requested unknown function '{function_name}'"
)
return text_response
except Exception as e:
if not LLMContextLengthExceededException(
str(e)
)._is_context_limit_error(str(e)):
logging.error(f"LiteLLM call failed: {str(e)}")
raise # Re-raise the exception after logging
raise
def supports_function_calling(self) -> bool:
try:
@@ -205,7 +269,10 @@ class LLM:
return False
def get_context_window_size(self) -> int:
# Only using 75% of the context window size to avoid cutting the message in the middle
"""
Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread.
"""
if self.context_window_size != 0:
return self.context_window_size
@@ -218,6 +285,10 @@ class LLM:
return self.context_window_size
def set_callbacks(self, callbacks: List[Any]):
"""
Attempt to keep a single set of callbacks in litellm by removing old
duplicates and adding new ones.
"""
with suppress_warnings():
callback_types = [type(callback) for callback in callbacks]
for callback in litellm.success_callback[:]:
@@ -254,15 +325,15 @@ class LLM:
success_callbacks = []
if success_callbacks_str:
success_callbacks = [
callback.strip() for callback in success_callbacks_str.split(",")
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
]
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
failure_callbacks = []
if failure_callbacks_str:
failure_callbacks = [
callback.strip() for callback in failure_callbacks_str.split(",")
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
]
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks
litellm.success_callback = success_callbacks
litellm.failure_callback = failure_callbacks