mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-08 15:48:29 +00:00

Non-streaming working again
@@ -279,6 +279,7 @@ class LLM:
    def _handle_streaming_response(
        self,
        params: Dict[str, Any],
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Handle a streaming response from the LLM.
@@ -295,13 +296,6 @@ class LLM:
        last_chunk = None
        chunk_count = 0
        debug_info = []
        aggregated_usage = {
            "prompt_tokens": 0,
            "completion_tokens": 0,
            "total_tokens": 0,
            "successful_requests": 0,
            "cached_prompt_tokens": 0,
        }

        # --- 2) Make sure stream is set to True
        params["stream"] = True
@@ -322,11 +316,8 @@ class LLM:
                if isinstance(chunk, ModelResponse):
                    debug_info.append("Chunk is ModelResponse")

                    # Capture and aggregate usage information from the chunk if available
                    chunk_usage = getattr(chunk, "usage", None)
                    if isinstance(chunk_usage, dict):
                        for key in aggregated_usage:
                            aggregated_usage[key] += chunk_usage.get(key, 0)
                    # Get usage information from the chunk
                    usage_info = getattr(chunk, "usage", None)

                    choices = getattr(chunk, "choices", [])
                    if choices and len(choices) > 0:
@@ -346,22 +337,18 @@ class LLM:

                        # Handle object-style choices
                        else:
                            # Try to access delta attribute safely
                            delta = getattr(choice, "delta", None)
                            debug_info.append(f"Delta: {delta}")

                            if delta is not None:
                                # Try to get content from delta.content
                                if (
                                    hasattr(delta, "content")
                                    and getattr(delta, "content", None) is not None
                                ):
                                    chunk_content = getattr(delta, "content")
                                # Some models return delta as a string
                                elif isinstance(delta, str):
                                    chunk_content = delta

                # Add content to response if found
                if chunk_content:
                    full_response += chunk_content
                    print(f"Chunk content: {chunk_content}")
@@ -377,11 +364,10 @@ class LLM:
                logging.warning(
                    "No chunks received in streaming response, falling back to non-streaming"
                )
                # Try non-streaming as fallback
                non_streaming_params = params.copy()
                non_streaming_params["stream"] = False
                return self._handle_non_streaming_response(
                    non_streaming_params, available_functions
                    non_streaming_params, callbacks, available_functions
                )

            # --- 5) Handle empty response with chunks
@@ -389,35 +375,23 @@ class LLM:
                logging.warning(
                    f"Received {chunk_count} chunks but no content. Debug info: {debug_info}"
                )
                if last_chunk is not None:
                    # Try to extract any content from the last chunk
                    if isinstance(last_chunk, ModelResponse):
                        # Capture and aggregate usage information from the last chunk if available
                        chunk_usage = getattr(last_chunk, "usage", None)
                        if isinstance(chunk_usage, dict):
                            for key in aggregated_usage:
                                aggregated_usage[key] += chunk_usage.get(key, 0)
                if last_chunk is not None and isinstance(last_chunk, ModelResponse):
                    usage_info = getattr(last_chunk, "usage", None)

                        choices = getattr(last_chunk, "choices", [])
                        if choices and len(choices) > 0:
                            choice = choices[0]

                            # Try to get content from message
                            message = getattr(choice, "message", None)
                            if message is not None and getattr(
                                message, "content", None
                            ):
                                full_response = getattr(message, "content")
                                logging.info(
                                    f"Extracted content from last chunk message: {full_response}"
                                )

                            # Try to get content from text (some models use this)
                            elif getattr(choice, "text", None):
                                full_response = getattr(choice, "text")
                                logging.info(
                                    f"Extracted text from last chunk: {full_response}"
                                )
                    choices = getattr(last_chunk, "choices", [])
                    if choices and len(choices) > 0:
                        choice = choices[0]
                        message = getattr(choice, "message", None)
                        if message is not None and getattr(message, "content", None):
                            full_response = getattr(message, "content")
                            logging.info(
                                f"Extracted content from last chunk message: {full_response}"
                            )
                        elif getattr(choice, "text", None):
                            full_response = getattr(choice, "text")
                            logging.info(
                                f"Extracted text from last chunk: {full_response}"
                            )

            # --- 6) If still empty, use a default response
            if not full_response.strip():
@@ -426,12 +400,7 @@ class LLM:

            # --- 7) Check for tool calls in the final response
            if isinstance(last_chunk, ModelResponse):
                # Capture and aggregate usage information from the last chunk if available
                chunk_usage = getattr(last_chunk, "usage", None)
                if isinstance(chunk_usage, dict):
                    for key in aggregated_usage:
                        aggregated_usage[key] += chunk_usage.get(key, 0)

                usage_info = getattr(last_chunk, "usage", None)
                choices = getattr(last_chunk, "choices", [])
                if choices and len(choices) > 0:
                    choice = choices[0]
@@ -444,21 +413,18 @@ class LLM:
                    if tool_result is not None:
                        return tool_result

            # --- 8) Log token usage if available
            # Use aggregated usage if any tokens were counted
            if any(value > 0 for value in aggregated_usage.values()):
                logging.info(
                    f"Aggregated token usage from streaming response: {aggregated_usage}"
                )
                if self.callbacks and len(self.callbacks) > 0:
                    for callback in self.callbacks:
                        if hasattr(callback, "log_success_event"):
                            callback.log_success_event(
                                kwargs=params,
                                response_obj={"usage": aggregated_usage},
                                start_time=0,
                                end_time=0,
                            )
            # --- 8) Log token usage if available in streaming mode
            # Use usage info from the last chunk if present
            usage_info = getattr(last_chunk, "usage", None) if last_chunk else None
            if usage_info and self.callbacks and len(self.callbacks) > 0:
                for callback in self.callbacks:
                    if hasattr(callback, "log_success_event"):
                        callback.log_success_event(
                            kwargs=params,
                            response_obj={"usage": usage_info},
                            start_time=0,
                            end_time=0,
                        )

            # --- 9) Emit completion event and return response
            self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
@@ -468,25 +434,21 @@ class LLM:
            logging.error(
                f"Error in streaming response: {str(e)}, Debug info: {debug_info}"
            )
            # If we have any response content, return it instead of failing
            if full_response.strip():
                logging.warning(f"Returning partial response despite error: {str(e)}")
                self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
                return full_response

            # Try non-streaming as fallback
            try:
                logging.warning("Falling back to non-streaming after error")
                non_streaming_params = params.copy()
                non_streaming_params["stream"] = False
                return self._handle_non_streaming_response(
                    non_streaming_params, available_functions
                    non_streaming_params, callbacks, available_functions
                )
            except Exception as fallback_error:
                logging.error(
                    f"Fallback to non-streaming also failed: {str(fallback_error)}"
                )
                # Return a default response as last resort
                default_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
                self._handle_emit_call_events(default_response, LLMCallType.LLM_CALL)
                return default_response
@@ -494,6 +456,7 @@ class LLM:
    def _handle_non_streaming_response(
        self,
        params: Dict[str, Any],
        callbacks: Optional[List[Any]] = None,
        available_functions: Optional[Dict[str, Any]] = None,
    ) -> str:
        """Handle a non-streaming response from the LLM.
@@ -507,16 +470,6 @@ class LLM:
        """
        # --- 1) Make the completion call
        response = litellm.completion(**params)
        # Extract usage info – if none is provided, default to zero
        usage_info = getattr(response, "usage", None)
        if usage_info is None:
            usage_info = {
                "prompt_tokens": 0,
                "completion_tokens": 0,
                "total_tokens": 0,
                "successful_requests": 0,
                "cached_prompt_tokens": 0,
            }

        # --- 2) Extract response message and content
        response_message = cast(Choices, cast(ModelResponse, response).choices)[
@@ -525,18 +478,17 @@ class LLM:
        text_response = response_message.content or ""

        # --- 3) Handle callbacks with usage info
        if self.callbacks and len(self.callbacks) > 0:
            for callback in self.callbacks:
        if callbacks and len(callbacks) > 0:
            for callback in callbacks:
                if hasattr(callback, "log_success_event"):
                    logging.info(
                        f"Token usage from non-streaming response: {usage_info}"
                    )
                    callback.log_success_event(
                        kwargs=params,
                        response_obj={"usage": usage_info},
                        start_time=0,
                        end_time=0,
                    )
                    usage_info = getattr(response, "usage", None)
                    if usage_info:
                        callback.log_success_event(
                            kwargs=params,
                            response_obj={"usage": usage_info},
                            start_time=0,
                            end_time=0,
                        )

        # --- 4) Check for tool calls
        tool_calls = getattr(response_message, "tool_calls", [])
@@ -672,10 +624,12 @@ class LLM:

            # --- 7) Make the completion call and handle response
            if self.stream:
                return self._handle_streaming_response(params, available_functions)
                return self._handle_streaming_response(
                    params, callbacks, available_functions
                )
            else:
                return self._handle_non_streaming_response(
                    params, available_functions
                    params, callbacks, available_functions
                )

        except Exception as e:
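Note on the callback contract used above: both handlers only probe each callback with hasattr(callback, "log_success_event") and then invoke it as callback.log_success_event(kwargs=params, response_obj={"usage": usage_info}, start_time=0, end_time=0). As a rough sketch (not crewAI's actual token handler; the class name and attributes below are illustrative only), an object compatible with those call sites could look like this:

from typing import Any, Dict


class UsageTally:
    """Illustrative callback that accumulates the usage data passed in by
    _handle_streaming_response and _handle_non_streaming_response."""

    def __init__(self) -> None:
        self.prompt_tokens = 0
        self.completion_tokens = 0
        self.total_tokens = 0
        self.successful_requests = 0

    def log_success_event(
        self,
        kwargs: Dict[str, Any],
        response_obj: Dict[str, Any],
        start_time: float,
        end_time: float,
    ) -> None:
        # The handlers pass usage as response_obj={"usage": usage_info}; the value
        # may be a plain dict or a litellm usage object, so read it defensively.
        usage = response_obj.get("usage") or {}
        read = usage.get if isinstance(usage, dict) else lambda key, default=0: getattr(usage, key, default)
        self.prompt_tokens += read("prompt_tokens", 0) or 0
        self.completion_tokens += read("completion_tokens", 0) or 0
        self.total_tokens += read("total_tokens", 0) or 0
        self.successful_requests += 1

The streaming handler iterates self.callbacks while the non-streaming handler now iterates the callbacks argument it receives, but in both cases the only requirement on the callback object is that it expose log_success_event.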
tests/cassettes/test_crew_kickoff_streaming_usage_metrics.yaml: new file, 2095 lines (file diff suppressed because one or more lines are too long)
tests/cassettes/test_crew_kickoff_usage_metrics.yaml: new file, 1713 lines (file diff suppressed because one or more lines are too long)
@@ -948,7 +948,7 @@ def test_api_calls_throttling(capsys):
        moveon.assert_called()


# @pytest.mark.vcr(filter_headers=["authorization"])
@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_usage_metrics():
    inputs = [
        {"topic": "dog"},
@@ -983,6 +983,41 @@ def test_crew_kickoff_usage_metrics():
        assert result.token_usage.cached_prompt_tokens == 0


@pytest.mark.vcr(filter_headers=["authorization"])
def test_crew_kickoff_streaming_usage_metrics():
    inputs = [
        {"topic": "dog"},
        {"topic": "cat"},
        {"topic": "apple"},
    ]

    agent = Agent(
        role="{topic} Researcher",
        goal="Express hot takes on {topic}.",
        backstory="You have a lot of experience with {topic}.",
        llm=LLM(model="gpt-4o", stream=True),
    )

    task = Task(
        description="Give me an analysis around {topic}.",
        expected_output="1 bullet point about {topic} that's under 15 words.",
        agent=agent,
    )

    # Use real LLM calls instead of mocking
    crew = Crew(agents=[agent], tasks=[task])
    results = crew.kickoff_for_each(inputs=inputs)

    assert len(results) == len(inputs)
    for result in results:
        # Assert that all required keys are in usage_metrics and their values are greater than 0
        assert result.token_usage.total_tokens > 0
        assert result.token_usage.prompt_tokens > 0
        assert result.token_usage.completion_tokens > 0
        assert result.token_usage.successful_requests > 0
        assert result.token_usage.cached_prompt_tokens == 0


def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():
    agent = Agent(
        role="test role",
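For reference, a single-input version of the new streaming test above, run outside pytest, would look roughly like the sketch below (illustrative only; it assumes the same public crewai classes the test uses and a configured OpenAI API key, and it makes a real model call):

from crewai import Agent, Crew, Task, LLM

agent = Agent(
    role="Dog Researcher",
    goal="Express hot takes on dogs.",
    backstory="You have a lot of experience with dogs.",
    llm=LLM(model="gpt-4o", stream=True),  # streaming enabled, as in the test above
)

task = Task(
    description="Give me an analysis around dogs.",
    expected_output="1 bullet point about dogs that's under 15 words.",
    agent=agent,
)

crew = Crew(agents=[agent], tasks=[task])
result = crew.kickoff()

# With the callback changes in llm.py, usage metrics should be populated
# even though the response was streamed.
print(result.token_usage.total_tokens)
print(result.token_usage.prompt_tokens)
print(result.token_usage.completion_tokens)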