mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 07:13:00 +00:00
fix: update llm parameter handling in human_feedback function (#4801)
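Modified the `llm` argument assignment in `human_feedback`: when `llm` is not a string, its `model` attribute is now stored (falling back to `None` if the object has no `model` attribute) instead of always storing `None`, so the model name is preserved for LLM objects as well as for plain model-name strings.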
This commit is contained in:
@@ -408,7 +408,7 @@ def human_feedback(
|
|||||||
emit=list(emit) if emit else None,
|
emit=list(emit) if emit else None,
|
||||||
default_outcome=default_outcome,
|
default_outcome=default_outcome,
|
||||||
metadata=metadata or {},
|
metadata=metadata or {},
|
||||||
llm=llm if isinstance(llm, str) else None,
|
llm=llm if isinstance(llm, str) else getattr(llm, "model", None),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Determine effective provider:
|
# Determine effective provider:
|
||||||
|
|||||||
@@ -971,6 +971,128 @@ class TestCollapseToOutcomeJsonParsing:
|
|||||||
assert mock_llm.call.call_count == 2
|
assert mock_llm.call.call_count == 2
|
||||||
|
|
||||||
|
|
||||||
|
class TestLLMObjectPreservedInContext:
|
||||||
|
"""Tests that BaseLLM objects have their model string preserved in PendingFeedbackContext."""
|
||||||
|
|
||||||
|
@patch("crewai.flow.flow.crewai_event_bus.emit")
|
||||||
|
def test_basellm_object_model_string_survives_roundtrip(self, mock_emit: MagicMock) -> None:
|
||||||
|
"""Test that when llm is a BaseLLM object, its model string is stored in context
|
||||||
|
so that outcome collapsing works after async pause/resume.
|
||||||
|
|
||||||
|
This is the exact bug: locally the sync path keeps the LLM object in memory,
|
||||||
|
but in production the async path serializes the context and the LLM object was
|
||||||
|
discarded (stored as None), causing resume to skip classification and always
|
||||||
|
fall back to emit[0].
|
||||||
|
"""
|
||||||
|
with tempfile.TemporaryDirectory() as tmpdir:
|
||||||
|
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||||
|
persistence = SQLiteFlowPersistence(db_path)
|
||||||
|
|
||||||
|
# Create a mock BaseLLM object (not a string)
|
||||||
|
mock_llm_obj = MagicMock()
|
||||||
|
mock_llm_obj.model = "gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
|
class PausingProvider:
|
||||||
|
def __init__(self, persistence: SQLiteFlowPersistence):
|
||||||
|
self.persistence = persistence
|
||||||
|
self.captured_context: PendingFeedbackContext | None = None
|
||||||
|
|
||||||
|
def request_feedback(
|
||||||
|
self, context: PendingFeedbackContext, flow: Flow
|
||||||
|
) -> str:
|
||||||
|
self.captured_context = context
|
||||||
|
self.persistence.save_pending_feedback(
|
||||||
|
flow_uuid=context.flow_id,
|
||||||
|
context=context,
|
||||||
|
state_data=flow.state if isinstance(flow.state, dict) else flow.state.model_dump(),
|
||||||
|
)
|
||||||
|
raise HumanFeedbackPending(context=context)
|
||||||
|
|
||||||
|
provider = PausingProvider(persistence)
|
||||||
|
|
||||||
|
class TestFlow(Flow):
|
||||||
|
result_path: str = ""
|
||||||
|
|
||||||
|
@start()
|
||||||
|
@human_feedback(
|
||||||
|
message="Approve?",
|
||||||
|
emit=["needs_changes", "approved"],
|
||||||
|
llm=mock_llm_obj,
|
||||||
|
default_outcome="approved",
|
||||||
|
provider=provider,
|
||||||
|
)
|
||||||
|
def review(self):
|
||||||
|
return "content for review"
|
||||||
|
|
||||||
|
@listen("approved")
|
||||||
|
def handle_approved(self):
|
||||||
|
self.result_path = "approved"
|
||||||
|
return "Approved!"
|
||||||
|
|
||||||
|
@listen("needs_changes")
|
||||||
|
def handle_changes(self):
|
||||||
|
self.result_path = "needs_changes"
|
||||||
|
return "Changes needed"
|
||||||
|
|
||||||
|
# Phase 1: Start flow (should pause)
|
||||||
|
flow1 = TestFlow(persistence=persistence)
|
||||||
|
result = flow1.kickoff()
|
||||||
|
assert isinstance(result, HumanFeedbackPending)
|
||||||
|
|
||||||
|
# Verify the context stored the model STRING, not None
|
||||||
|
assert provider.captured_context is not None
|
||||||
|
assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
|
# Verify it survives persistence roundtrip
|
||||||
|
flow_id = result.context.flow_id
|
||||||
|
loaded = persistence.load_pending_feedback(flow_id)
|
||||||
|
assert loaded is not None
|
||||||
|
_, loaded_context = loaded
|
||||||
|
assert loaded_context.llm == "gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
|
# Phase 2: Resume with positive feedback - should use LLM to classify
|
||||||
|
flow2 = TestFlow.from_pending(flow_id, persistence)
|
||||||
|
assert flow2._pending_feedback_context is not None
|
||||||
|
assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
|
||||||
|
|
||||||
|
# Mock _collapse_to_outcome to verify it gets called (not skipped)
|
||||||
|
with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
|
||||||
|
flow2.resume("this looks good, proceed!")
|
||||||
|
|
||||||
|
# The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
|
||||||
|
mock_collapse.assert_called_once_with(
|
||||||
|
feedback="this looks good, proceed!",
|
||||||
|
outcomes=["needs_changes", "approved"],
|
||||||
|
llm="gemini/gemini-2.0-flash",
|
||||||
|
)
|
||||||
|
assert flow2.last_human_feedback.outcome == "approved"
|
||||||
|
assert flow2.result_path == "approved"
|
||||||
|
|
||||||
|
def test_string_llm_still_works(self) -> None:
|
||||||
|
"""Test that passing llm as a string still works correctly."""
|
||||||
|
context = PendingFeedbackContext(
|
||||||
|
flow_id="str-llm-test",
|
||||||
|
flow_class="test.Flow",
|
||||||
|
method_name="review",
|
||||||
|
method_output="output",
|
||||||
|
message="Review:",
|
||||||
|
emit=["approved", "rejected"],
|
||||||
|
llm="gpt-4o-mini",
|
||||||
|
)
|
||||||
|
|
||||||
|
serialized = context.to_dict()
|
||||||
|
restored = PendingFeedbackContext.from_dict(serialized)
|
||||||
|
assert restored.llm == "gpt-4o-mini"
|
||||||
|
|
||||||
|
def test_none_llm_when_no_model_attr(self) -> None:
|
||||||
|
"""Test that llm is None when object has no model attribute."""
|
||||||
|
mock_obj = MagicMock(spec=[]) # No attributes
|
||||||
|
|
||||||
|
# Simulate what the decorator does
|
||||||
|
llm_value = mock_obj if isinstance(mock_obj, str) else getattr(mock_obj, "model", None)
|
||||||
|
assert llm_value is None
|
||||||
|
|
||||||
|
|
||||||
class TestAsyncHumanFeedbackEdgeCases:
|
class TestAsyncHumanFeedbackEdgeCases:
|
||||||
"""Edge case tests for async human feedback."""
|
"""Edge case tests for async human feedback."""
|
||||||
|
|
||||||
|
|||||||
Reference in New Issue
Block a user