Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-09 08:08:32 +00:00
Non-streaming working again
@@ -279,6 +279,7 @@ class LLM:
     def _handle_streaming_response(
         self,
         params: Dict[str, Any],
+        callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> str:
         """Handle a streaming response from the LLM.
@@ -295,13 +296,6 @@ class LLM:
         last_chunk = None
         chunk_count = 0
         debug_info = []
-        aggregated_usage = {
-            "prompt_tokens": 0,
-            "completion_tokens": 0,
-            "total_tokens": 0,
-            "successful_requests": 0,
-            "cached_prompt_tokens": 0,
-        }

         # --- 2) Make sure stream is set to True
         params["stream"] = True
@@ -322,11 +316,8 @@ class LLM:
                 if isinstance(chunk, ModelResponse):
                     debug_info.append("Chunk is ModelResponse")

-                    # Capture and aggregate usage information from the chunk if available
-                    chunk_usage = getattr(chunk, "usage", None)
-                    if isinstance(chunk_usage, dict):
-                        for key in aggregated_usage:
-                            aggregated_usage[key] += chunk_usage.get(key, 0)
+                    # Get usage information from the chunk
+                    usage_info = getattr(chunk, "usage", None)

                     choices = getattr(chunk, "choices", [])
                     if choices and len(choices) > 0:
@@ -346,22 +337,18 @@ class LLM:

                         # Handle object-style choices
                         else:
-                            # Try to access delta attribute safely
                             delta = getattr(choice, "delta", None)
                             debug_info.append(f"Delta: {delta}")

                             if delta is not None:
-                                # Try to get content from delta.content
                                 if (
                                     hasattr(delta, "content")
                                     and getattr(delta, "content", None) is not None
                                 ):
                                     chunk_content = getattr(delta, "content")
-                                # Some models return delta as a string
                                 elif isinstance(delta, str):
                                     chunk_content = delta

-                        # Add content to response if found
                         if chunk_content:
                             full_response += chunk_content
                             print(f"Chunk content: {chunk_content}")
@@ -377,11 +364,10 @@ class LLM:
                 logging.warning(
                     "No chunks received in streaming response, falling back to non-streaming"
                 )
-                # Try non-streaming as fallback
                 non_streaming_params = params.copy()
                 non_streaming_params["stream"] = False
                 return self._handle_non_streaming_response(
-                    non_streaming_params, available_functions
+                    non_streaming_params, callbacks, available_functions
                 )

             # --- 5) Handle empty response with chunks
@@ -389,35 +375,23 @@ class LLM:
                 logging.warning(
                     f"Received {chunk_count} chunks but no content. Debug info: {debug_info}"
                 )
-                if last_chunk is not None:
-                    # Try to extract any content from the last chunk
-                    if isinstance(last_chunk, ModelResponse):
-                        # Capture and aggregate usage information from the last chunk if available
-                        chunk_usage = getattr(last_chunk, "usage", None)
-                        if isinstance(chunk_usage, dict):
-                            for key in aggregated_usage:
-                                aggregated_usage[key] += chunk_usage.get(key, 0)
-
-                        choices = getattr(last_chunk, "choices", [])
-                        if choices and len(choices) > 0:
-                            choice = choices[0]
-
-                            # Try to get content from message
-                            message = getattr(choice, "message", None)
-                            if message is not None and getattr(
-                                message, "content", None
-                            ):
-                                full_response = getattr(message, "content")
-                                logging.info(
-                                    f"Extracted content from last chunk message: {full_response}"
-                                )
-
-                            # Try to get content from text (some models use this)
-                            elif getattr(choice, "text", None):
-                                full_response = getattr(choice, "text")
-                                logging.info(
-                                    f"Extracted text from last chunk: {full_response}"
-                                )
+                if last_chunk is not None and isinstance(last_chunk, ModelResponse):
+                    usage_info = getattr(last_chunk, "usage", None)
+
+                    choices = getattr(last_chunk, "choices", [])
+                    if choices and len(choices) > 0:
+                        choice = choices[0]
+                        message = getattr(choice, "message", None)
+                        if message is not None and getattr(message, "content", None):
+                            full_response = getattr(message, "content")
+                            logging.info(
+                                f"Extracted content from last chunk message: {full_response}"
+                            )
+                        elif getattr(choice, "text", None):
+                            full_response = getattr(choice, "text")
+                            logging.info(
+                                f"Extracted text from last chunk: {full_response}"
+                            )

             # --- 6) If still empty, use a default response
             if not full_response.strip():
@@ -426,12 +400,7 @@ class LLM:

             # --- 7) Check for tool calls in the final response
             if isinstance(last_chunk, ModelResponse):
-                # Capture and aggregate usage information from the last chunk if available
-                chunk_usage = getattr(last_chunk, "usage", None)
-                if isinstance(chunk_usage, dict):
-                    for key in aggregated_usage:
-                        aggregated_usage[key] += chunk_usage.get(key, 0)
-
+                usage_info = getattr(last_chunk, "usage", None)
                 choices = getattr(last_chunk, "choices", [])
                 if choices and len(choices) > 0:
                     choice = choices[0]
@@ -444,21 +413,18 @@ class LLM:
                     if tool_result is not None:
                         return tool_result

-            # --- 8) Log token usage if available
-            # Use aggregated usage if any tokens were counted
-            if any(value > 0 for value in aggregated_usage.values()):
-                logging.info(
-                    f"Aggregated token usage from streaming response: {aggregated_usage}"
-                )
-                if self.callbacks and len(self.callbacks) > 0:
-                    for callback in self.callbacks:
-                        if hasattr(callback, "log_success_event"):
-                            callback.log_success_event(
-                                kwargs=params,
-                                response_obj={"usage": aggregated_usage},
-                                start_time=0,
-                                end_time=0,
-                            )
+            # --- 8) Log token usage if available in streaming mode
+            # Use usage info from the last chunk if present
+            usage_info = getattr(last_chunk, "usage", None) if last_chunk else None
+            if usage_info and self.callbacks and len(self.callbacks) > 0:
+                for callback in self.callbacks:
+                    if hasattr(callback, "log_success_event"):
+                        callback.log_success_event(
+                            kwargs=params,
+                            response_obj={"usage": usage_info},
+                            start_time=0,
+                            end_time=0,
+                        )

             # --- 9) Emit completion event and return response
             self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
@@ -468,25 +434,21 @@ class LLM:
             logging.error(
                 f"Error in streaming response: {str(e)}, Debug info: {debug_info}"
             )
-            # If we have any response content, return it instead of failing
             if full_response.strip():
                 logging.warning(f"Returning partial response despite error: {str(e)}")
                 self._handle_emit_call_events(full_response, LLMCallType.LLM_CALL)
                 return full_response
-
-            # Try non-streaming as fallback
             try:
                 logging.warning("Falling back to non-streaming after error")
                 non_streaming_params = params.copy()
                 non_streaming_params["stream"] = False
                 return self._handle_non_streaming_response(
-                    non_streaming_params, available_functions
+                    non_streaming_params, callbacks, available_functions
                 )
             except Exception as fallback_error:
                 logging.error(
                     f"Fallback to non-streaming also failed: {str(fallback_error)}"
                 )
-                # Return a default response as last resort
                 default_response = "I apologize, but I couldn't generate a proper response. Please try again or rephrase your request."
                 self._handle_emit_call_events(default_response, LLMCallType.LLM_CALL)
                 return default_response
@@ -494,6 +456,7 @@ class LLM:
     def _handle_non_streaming_response(
         self,
         params: Dict[str, Any],
+        callbacks: Optional[List[Any]] = None,
         available_functions: Optional[Dict[str, Any]] = None,
     ) -> str:
         """Handle a non-streaming response from the LLM.
@@ -507,16 +470,6 @@ class LLM:
         """
         # --- 1) Make the completion call
         response = litellm.completion(**params)
-        # Extract usage info – if none is provided, default to zero
-        usage_info = getattr(response, "usage", None)
-        if usage_info is None:
-            usage_info = {
-                "prompt_tokens": 0,
-                "completion_tokens": 0,
-                "total_tokens": 0,
-                "successful_requests": 0,
-                "cached_prompt_tokens": 0,
-            }

         # --- 2) Extract response message and content
         response_message = cast(Choices, cast(ModelResponse, response).choices)[
@@ -525,18 +478,17 @@ class LLM:
         text_response = response_message.content or ""

         # --- 3) Handle callbacks with usage info
-        if self.callbacks and len(self.callbacks) > 0:
-            for callback in self.callbacks:
+        if callbacks and len(callbacks) > 0:
+            for callback in callbacks:
                 if hasattr(callback, "log_success_event"):
-                    logging.info(
-                        f"Token usage from non-streaming response: {usage_info}"
-                    )
-                    callback.log_success_event(
-                        kwargs=params,
-                        response_obj={"usage": usage_info},
-                        start_time=0,
-                        end_time=0,
-                    )
+                    usage_info = getattr(response, "usage", None)
+                    if usage_info:
+                        callback.log_success_event(
+                            kwargs=params,
+                            response_obj={"usage": usage_info},
+                            start_time=0,
+                            end_time=0,
+                        )

         # --- 4) Check for tool calls
         tool_calls = getattr(response_message, "tool_calls", [])
@@ -672,10 +624,12 @@ class LLM:

             # --- 7) Make the completion call and handle response
             if self.stream:
-                return self._handle_streaming_response(params, available_functions)
+                return self._handle_streaming_response(
+                    params, callbacks, available_functions
+                )
             else:
                 return self._handle_non_streaming_response(
-                    params, available_functions
+                    params, callbacks, available_functions
                 )

         except Exception as e:
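Note on the change above: the streaming handler no longer aggregates usage across chunks. It accumulates the streamed delta content, reads the usage block from the final chunk, and reports it to any callback that exposes log_success_event, with callbacks now passed explicitly rather than read only from self.callbacks. The snippet below is a minimal, self-contained sketch of that pattern, not the project's implementation; UsageRecorder, consume_stream, and the SimpleNamespace fake chunks are invented for illustration.

from types import SimpleNamespace
from typing import Any, Dict, List, Optional


class UsageRecorder:
    """Hypothetical callback: anything with log_success_event can receive usage."""

    def __init__(self) -> None:
        self.usage: Optional[Dict[str, Any]] = None

    def log_success_event(self, kwargs, response_obj, start_time, end_time) -> None:
        self.usage = response_obj.get("usage")


def consume_stream(chunks: List[Any], callbacks: Optional[List[Any]] = None) -> str:
    """Accumulate delta content, then report the last chunk's usage (as in the diff)."""
    full_response, last_chunk = "", None
    for chunk in chunks:
        last_chunk = chunk
        choices = getattr(chunk, "choices", [])
        if choices:
            delta = getattr(choices[0], "delta", None)
            content = getattr(delta, "content", None) if delta is not None else None
            if isinstance(delta, str):  # some providers stream plain strings
                content = delta
            if content:
                full_response += content
    usage_info = getattr(last_chunk, "usage", None) if last_chunk else None
    if usage_info and callbacks:
        for cb in callbacks:
            if hasattr(cb, "log_success_event"):
                cb.log_success_event(kwargs={}, response_obj={"usage": usage_info},
                                     start_time=0, end_time=0)
    return full_response


# Example with fake chunks (no LLM call is made):
fake_chunks = [
    SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="Hello "))], usage=None),
    SimpleNamespace(choices=[SimpleNamespace(delta=SimpleNamespace(content="world"))],
                    usage={"prompt_tokens": 5, "completion_tokens": 2, "total_tokens": 7}),
]
recorder = UsageRecorder()
print(consume_stream(fake_chunks, callbacks=[recorder]))  # -> "Hello world"
print(recorder.usage)  # -> usage dict from the last fake chunk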
tests/cassettes/test_crew_kickoff_streaming_usage_metrics.yaml (new file, 2095 lines)
File diff suppressed because one or more lines are too long

tests/cassettes/test_crew_kickoff_usage_metrics.yaml (new file, 1713 lines)
File diff suppressed because one or more lines are too long
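These cassettes record the real API traffic so the usage-metrics tests can replay it deterministically under VCR. As a rough picture of what the token_usage fields asserted in those tests add up to, here is an illustrative accumulator using the same field names; crewAI's actual UsageMetrics type is not shown in this diff and may differ in detail.

from dataclasses import dataclass


@dataclass
class SimpleUsageMetrics:
    """Illustrative stand-in for the token_usage object the tests inspect."""
    total_tokens: int = 0
    prompt_tokens: int = 0
    completion_tokens: int = 0
    successful_requests: int = 0
    cached_prompt_tokens: int = 0

    def add(self, usage: dict) -> None:
        # One successful LLM call contributes its usage block.
        self.prompt_tokens += usage.get("prompt_tokens", 0)
        self.completion_tokens += usage.get("completion_tokens", 0)
        self.total_tokens += usage.get("total_tokens", 0)
        self.cached_prompt_tokens += usage.get("cached_prompt_tokens", 0)
        self.successful_requests += 1


metrics = SimpleUsageMetrics()
metrics.add({"prompt_tokens": 120, "completion_tokens": 30, "total_tokens": 150})
assert metrics.total_tokens > 0 and metrics.successful_requests == 1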
@@ -948,7 +948,7 @@ def test_api_calls_throttling(capsys):
     moveon.assert_called()


-# @pytest.mark.vcr(filter_headers=["authorization"])
+@pytest.mark.vcr(filter_headers=["authorization"])
 def test_crew_kickoff_usage_metrics():
     inputs = [
         {"topic": "dog"},
@@ -983,6 +983,41 @@ def test_crew_kickoff_usage_metrics():
         assert result.token_usage.cached_prompt_tokens == 0


+@pytest.mark.vcr(filter_headers=["authorization"])
+def test_crew_kickoff_streaming_usage_metrics():
+    inputs = [
+        {"topic": "dog"},
+        {"topic": "cat"},
+        {"topic": "apple"},
+    ]
+
+    agent = Agent(
+        role="{topic} Researcher",
+        goal="Express hot takes on {topic}.",
+        backstory="You have a lot of experience with {topic}.",
+        llm=LLM(model="gpt-4o", stream=True),
+    )
+
+    task = Task(
+        description="Give me an analysis around {topic}.",
+        expected_output="1 bullet point about {topic} that's under 15 words.",
+        agent=agent,
+    )
+
+    # Use real LLM calls instead of mocking
+    crew = Crew(agents=[agent], tasks=[task])
+    results = crew.kickoff_for_each(inputs=inputs)
+
+    assert len(results) == len(inputs)
+    for result in results:
+        # Assert that all required keys are in usage_metrics and their values are greater than 0
+        assert result.token_usage.total_tokens > 0
+        assert result.token_usage.prompt_tokens > 0
+        assert result.token_usage.completion_tokens > 0
+        assert result.token_usage.successful_requests > 0
+        assert result.token_usage.cached_prompt_tokens == 0
+
+
 def test_agents_rpm_is_never_set_if_crew_max_RPM_is_not_set():
     agent = Agent(
         role="test role",
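Alongside the new usage assertions, the commit keeps the handler's fallback path: when streaming yields no content or raises, the same request is retried with stream set to False (see the two fallback hunks in the streaming handler above). A compact, hypothetical reduction of that control flow follows; stream_with_fallback and the lambda stand-ins are invented for the example and are not part of the codebase.

import logging
from typing import Any, Callable, Dict


def stream_with_fallback(
    params: Dict[str, Any],
    streaming_call: Callable[[Dict[str, Any]], str],
    non_streaming_call: Callable[[Dict[str, Any]], str],
) -> str:
    """Try a streaming request first; on empty output or an error, retry without streaming."""
    streaming_params = dict(params, stream=True)
    try:
        result = streaming_call(streaming_params)
        if result.strip():
            return result
        logging.warning("No streamed content, falling back to non-streaming")
    except Exception as exc:
        logging.warning("Streaming failed (%s), falling back to non-streaming", exc)
    non_streaming_params = params.copy()
    non_streaming_params["stream"] = False
    return non_streaming_call(non_streaming_params)


# Toy usage with stand-in callables (no real LLM involved):
print(stream_with_fallback(
    {"model": "gpt-4o"},
    streaming_call=lambda p: "",               # simulate an empty stream
    non_streaming_call=lambda p: "fallback answer",
))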