mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-03 00:02:36 +00:00
core loop should be working and ready for testing.
This commit is contained in:
@@ -1,3 +1,4 @@
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -7,7 +8,7 @@ from contextlib import contextmanager
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import litellm
|
||||
from litellm import get_supported_openai_params
|
||||
from litellm import ModelResponse, get_supported_openai_params
|
||||
|
||||
from crewai.utilities.exceptions.context_window_exceeding_exception import (
|
||||
LLMContextLengthExceededException,
|
||||
@@ -21,6 +22,7 @@ class FilteredStream:
|
||||
|
||||
def write(self, s) -> int:
|
||||
with self._lock:
|
||||
# Filter out extraneous messages from LiteLLM
|
||||
if (
|
||||
"Give Feedback / Get Help: https://github.com/BerriAI/litellm/issues/new"
|
||||
in s
|
||||
@@ -80,11 +82,9 @@ def suppress_warnings():
|
||||
old_stderr = sys.stderr
|
||||
sys.stdout = FilteredStream(old_stdout)
|
||||
sys.stderr = FilteredStream(old_stderr)
|
||||
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
# Restore stdout and stderr
|
||||
sys.stdout = old_stdout
|
||||
sys.stderr = old_stderr
|
||||
|
||||
@@ -135,8 +135,10 @@ class LLM:
|
||||
self.context_window_size = 0
|
||||
self.kwargs = kwargs
|
||||
|
||||
# For safety, we disable passing init params to next calls
|
||||
litellm.drop_params = True
|
||||
litellm.set_verbose = False
|
||||
|
||||
self.set_callbacks(callbacks)
|
||||
self.set_env_callbacks()
|
||||
|
||||
@@ -173,8 +175,6 @@ class LLM:
|
||||
Create an LLM instance from a dict.
|
||||
We assume the dict has all relevant keys that match what's in the constructor.
|
||||
"""
|
||||
# We can pop off fields we know, then pass the rest into **kwargs
|
||||
# so that any leftover keys still get passed into the LLM constructor.
|
||||
known_fields = {}
|
||||
known_fields["model"] = data.pop("model", None)
|
||||
known_fields["timeout"] = data.pop("timeout", None)
|
||||
@@ -196,15 +196,37 @@ class LLM:
|
||||
known_fields["api_key"] = data.pop("api_key", None)
|
||||
known_fields["callbacks"] = data.pop("callbacks", None)
|
||||
|
||||
# leftover keys go into kwargs:
|
||||
return cls(**known_fields, **data)
|
||||
|
||||
def call(self, messages: List[Dict[str, str]], callbacks: List[Any] = []) -> str:
|
||||
def call(
|
||||
self,
|
||||
messages: List[Dict[str, str]],
|
||||
tools: Optional[List[dict]] = None,
|
||||
callbacks: Optional[List[Any]] = None,
|
||||
available_functions: Optional[Dict[str, Any]] = None,
|
||||
) -> str:
|
||||
"""
|
||||
High-level call method that:
|
||||
1) Calls litellm.completion
|
||||
2) Checks for function/tool calls
|
||||
3) If tool calls found:
|
||||
a) executes each function
|
||||
b) appends their output as tool messages
|
||||
c) calls litellm.completion again with the updated messages
|
||||
4) Returns the final text response
|
||||
|
||||
:param messages: The conversation messages
|
||||
:param tools: Optional list of function schemas for function calling
|
||||
:param callbacks: Optional list of callbacks
|
||||
:param available_functions: A dictionary mapping function_name -> actual Python function
|
||||
:return: Final text response from the LLM
|
||||
"""
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
if callbacks:
|
||||
self.set_callbacks(callbacks)
|
||||
|
||||
try:
|
||||
# --- 1) Make first completion call
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": messages,
|
||||
@@ -225,21 +247,71 @@ class LLM:
|
||||
"api_version": self.api_version,
|
||||
"api_key": self.api_key,
|
||||
"stream": False,
|
||||
"tools": tools, # pass the tool schema
|
||||
**self.kwargs,
|
||||
}
|
||||
|
||||
# Remove None values to avoid passing unnecessary parameters
|
||||
# remove None values
|
||||
params = {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
response = litellm.completion(**params)
|
||||
return response["choices"][0]["message"]["content"]
|
||||
response_message = response.choices[0].message
|
||||
text_response = response_message.content or ""
|
||||
tool_calls = getattr(response_message, "tool_calls", [])
|
||||
|
||||
# --- 2) If no tool calls, we can just return
|
||||
if not tool_calls or not available_functions:
|
||||
return text_response
|
||||
|
||||
# --- 3) We have tool calls and a dictionary of available functions
|
||||
# run them, append output to messages
|
||||
for tool_call in tool_calls:
|
||||
function_name = tool_call.function.name
|
||||
if function_name in available_functions:
|
||||
# parse arguments
|
||||
function_args = {}
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except Exception as e:
|
||||
logging.warning(f"Failed to parse function arguments: {e}")
|
||||
|
||||
fn = available_functions[function_name]
|
||||
# call the actual tool function
|
||||
result = fn(**function_args)
|
||||
|
||||
# append the "tool" response to messages
|
||||
messages.append(
|
||||
{
|
||||
"tool_call_id": tool_call.id,
|
||||
"role": "tool",
|
||||
"name": function_name,
|
||||
"content": str(result),
|
||||
}
|
||||
)
|
||||
else:
|
||||
logging.warning(
|
||||
f"Tool call requested unknown function '{function_name}'"
|
||||
)
|
||||
|
||||
# --- 4) Make a second call so the LLM can incorporate the tool results
|
||||
second_params = dict(params) # copy the same params
|
||||
second_params["messages"] = messages
|
||||
# We'll remove "tools" from second call, or keep it if you want more calls
|
||||
# but typically you keep it in case it wants additional calls
|
||||
second_response = litellm.completion(**second_params)
|
||||
second_msg = second_response.choices[0].message
|
||||
final_response = second_msg.content or ""
|
||||
|
||||
return final_response
|
||||
|
||||
except Exception as e:
|
||||
# check if context length was exceeded, otherwise log
|
||||
if not LLMContextLengthExceededException(
|
||||
str(e)
|
||||
)._is_context_limit_error(str(e)):
|
||||
logging.error(f"LiteLLM call failed: {str(e)}")
|
||||
|
||||
raise # Re-raise the exception after logging
|
||||
# re-raise
|
||||
raise
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
try:
|
||||
@@ -258,7 +330,10 @@ class LLM:
|
||||
return False
|
||||
|
||||
def get_context_window_size(self) -> int:
|
||||
# Only using 75% of the context window size to avoid cutting the message in the middle
|
||||
"""
|
||||
Returns the context window size, using 75% of the maximum to avoid
|
||||
cutting off messages mid-thread.
|
||||
"""
|
||||
if self.context_window_size != 0:
|
||||
return self.context_window_size
|
||||
|
||||
@@ -271,6 +346,10 @@ class LLM:
|
||||
return self.context_window_size
|
||||
|
||||
def set_callbacks(self, callbacks: List[Any]):
|
||||
"""
|
||||
Attempt to keep a single set of callbacks in litellm by removing old
|
||||
duplicates and adding new ones.
|
||||
"""
|
||||
callback_types = [type(callback) for callback in callbacks]
|
||||
for callback in litellm.success_callback[:]:
|
||||
if type(callback) in callback_types:
|
||||
@@ -285,34 +364,19 @@ class LLM:
|
||||
def set_env_callbacks(self):
|
||||
"""
|
||||
Sets the success and failure callbacks for the LiteLLM library from environment variables.
|
||||
|
||||
This method reads the `LITELLM_SUCCESS_CALLBACKS` and `LITELLM_FAILURE_CALLBACKS`
|
||||
environment variables, which should contain comma-separated lists of callback names.
|
||||
It then assigns these lists to `litellm.success_callback` and `litellm.failure_callback`,
|
||||
respectively.
|
||||
|
||||
If the environment variables are not set or are empty, the corresponding callback lists
|
||||
will be set to empty lists.
|
||||
|
||||
Example:
|
||||
LITELLM_SUCCESS_CALLBACKS="langfuse,langsmith"
|
||||
LITELLM_FAILURE_CALLBACKS="langfuse"
|
||||
|
||||
This will set `litellm.success_callback` to ["langfuse", "langsmith"] and
|
||||
`litellm.failure_callback` to ["langfuse"].
|
||||
"""
|
||||
success_callbacks_str = os.environ.get("LITELLM_SUCCESS_CALLBACKS", "")
|
||||
success_callbacks = []
|
||||
if success_callbacks_str:
|
||||
success_callbacks = [
|
||||
callback.strip() for callback in success_callbacks_str.split(",")
|
||||
cb.strip() for cb in success_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
failure_callbacks_str = os.environ.get("LITELLM_FAILURE_CALLBACKS", "")
|
||||
failure_callbacks = []
|
||||
if failure_callbacks_str:
|
||||
failure_callbacks = [
|
||||
callback.strip() for callback in failure_callbacks_str.split(",")
|
||||
cb.strip() for cb in failure_callbacks_str.split(",") if cb.strip()
|
||||
]
|
||||
|
||||
litellm.success_callback = success_callbacks
|
||||
|
||||
Reference in New Issue
Block a user