fix: update llm parameter handling in human_feedback function (#4801)

Modified the llm parameter assignment to retrieve the model attribute from llm if it is not a string, ensuring compatibility with different llm types.
This commit is contained in:
João Moura
2026-03-10 10:27:09 -07:00
committed by GitHub
parent d9f6e2222f
commit f070ce8abd
2 changed files with 123 additions and 1 deletions

View File

@@ -408,7 +408,7 @@ def human_feedback(
emit=list(emit) if emit else None,
default_outcome=default_outcome,
metadata=metadata or {},
llm=llm if isinstance(llm, str) else None,
llm=llm if isinstance(llm, str) else getattr(llm, "model", None),
)
# Determine effective provider:

View File

@@ -971,6 +971,128 @@ class TestCollapseToOutcomeJsonParsing:
assert mock_llm.call.call_count == 2
class TestLLMObjectPreservedInContext:
"""Tests that BaseLLM objects have their model string preserved in PendingFeedbackContext."""
@patch("crewai.flow.flow.crewai_event_bus.emit")
def test_basellm_object_model_string_survives_roundtrip(self, mock_emit: MagicMock) -> None:
"""Test that when llm is a BaseLLM object, its model string is stored in context
so that outcome collapsing works after async pause/resume.
This is the exact bug: locally the sync path keeps the LLM object in memory,
but in production the async path serializes the context and the LLM object was
discarded (stored as None), causing resume to skip classification and always
fall back to emit[0].
"""
with tempfile.TemporaryDirectory() as tmpdir:
db_path = os.path.join(tmpdir, "test_flows.db")
persistence = SQLiteFlowPersistence(db_path)
# Create a mock BaseLLM object (not a string)
mock_llm_obj = MagicMock()
mock_llm_obj.model = "gemini/gemini-2.0-flash"
class PausingProvider:
def __init__(self, persistence: SQLiteFlowPersistence):
self.persistence = persistence
self.captured_context: PendingFeedbackContext | None = None
def request_feedback(
self, context: PendingFeedbackContext, flow: Flow
) -> str:
self.captured_context = context
self.persistence.save_pending_feedback(
flow_uuid=context.flow_id,
context=context,
state_data=flow.state if isinstance(flow.state, dict) else flow.state.model_dump(),
)
raise HumanFeedbackPending(context=context)
provider = PausingProvider(persistence)
class TestFlow(Flow):
result_path: str = ""
@start()
@human_feedback(
message="Approve?",
emit=["needs_changes", "approved"],
llm=mock_llm_obj,
default_outcome="approved",
provider=provider,
)
def review(self):
return "content for review"
@listen("approved")
def handle_approved(self):
self.result_path = "approved"
return "Approved!"
@listen("needs_changes")
def handle_changes(self):
self.result_path = "needs_changes"
return "Changes needed"
# Phase 1: Start flow (should pause)
flow1 = TestFlow(persistence=persistence)
result = flow1.kickoff()
assert isinstance(result, HumanFeedbackPending)
# Verify the context stored the model STRING, not None
assert provider.captured_context is not None
assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
# Verify it survives persistence roundtrip
flow_id = result.context.flow_id
loaded = persistence.load_pending_feedback(flow_id)
assert loaded is not None
_, loaded_context = loaded
assert loaded_context.llm == "gemini/gemini-2.0-flash"
# Phase 2: Resume with positive feedback - should use LLM to classify
flow2 = TestFlow.from_pending(flow_id, persistence)
assert flow2._pending_feedback_context is not None
assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
# Mock _collapse_to_outcome to verify it gets called (not skipped)
with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
flow2.resume("this looks good, proceed!")
# The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
mock_collapse.assert_called_once_with(
feedback="this looks good, proceed!",
outcomes=["needs_changes", "approved"],
llm="gemini/gemini-2.0-flash",
)
assert flow2.last_human_feedback.outcome == "approved"
assert flow2.result_path == "approved"
def test_string_llm_still_works(self) -> None:
"""Test that passing llm as a string still works correctly."""
context = PendingFeedbackContext(
flow_id="str-llm-test",
flow_class="test.Flow",
method_name="review",
method_output="output",
message="Review:",
emit=["approved", "rejected"],
llm="gpt-4o-mini",
)
serialized = context.to_dict()
restored = PendingFeedbackContext.from_dict(serialized)
assert restored.llm == "gpt-4o-mini"
def test_none_llm_when_no_model_attr(self) -> None:
"""Test that llm is None when object has no model attribute."""
mock_obj = MagicMock(spec=[]) # No attributes
# Simulate what the decorator does
llm_value = mock_obj if isinstance(mock_obj, str) else getattr(mock_obj, "model", None)
assert llm_value is None
class TestAsyncHumanFeedbackEdgeCases:
"""Edge case tests for async human feedback."""