diff --git a/lib/crewai/tests/test_async_human_feedback.py b/lib/crewai/tests/test_async_human_feedback.py index 3fc222387..a72147213 100644 --- a/lib/crewai/tests/test_async_human_feedback.py +++ b/lib/crewai/tests/test_async_human_feedback.py @@ -988,11 +988,9 @@ class TestLLMObjectPreservedInContext: db_path = os.path.join(tmpdir, "test_flows.db") persistence = SQLiteFlowPersistence(db_path) - # Create a mock BaseLLM object (not a string) - # Simulates LLM(model="gemini-2.0-flash", provider="gemini") - mock_llm_obj = MagicMock() - mock_llm_obj.model = "gemini-2.0-flash" - mock_llm_obj.provider = "gemini" + # Create a real LLM object (not a string) + from crewai.llm import LLM + mock_llm_obj = LLM(model="gemini-2.0-flash", provider="gemini") class PausingProvider: def __init__(self, persistence: SQLiteFlowPersistence): @@ -1041,32 +1039,37 @@ class TestLLMObjectPreservedInContext: result = flow1.kickoff() assert isinstance(result, HumanFeedbackPending) - # Verify the context stored the model STRING, not None + # Verify the context stored the model config dict, not None assert provider.captured_context is not None - assert provider.captured_context.llm == "gemini/gemini-2.0-flash" + assert isinstance(provider.captured_context.llm, dict) + assert provider.captured_context.llm["model"] == "gemini/gemini-2.0-flash" # Verify it survives persistence roundtrip flow_id = result.context.flow_id loaded = persistence.load_pending_feedback(flow_id) assert loaded is not None _, loaded_context = loaded - assert loaded_context.llm == "gemini/gemini-2.0-flash" + assert isinstance(loaded_context.llm, dict) + assert loaded_context.llm["model"] == "gemini/gemini-2.0-flash" # Phase 2: Resume with positive feedback - should use LLM to classify flow2 = TestFlow.from_pending(flow_id, persistence) assert flow2._pending_feedback_context is not None - assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash" + assert isinstance(flow2._pending_feedback_context.llm, dict) + assert flow2._pending_feedback_context.llm["model"] == "gemini/gemini-2.0-flash" # Mock _collapse_to_outcome to verify it gets called (not skipped) with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse: flow2.resume("this looks good, proceed!") # The key assertion: _collapse_to_outcome was called (not skipped due to llm=None) - mock_collapse.assert_called_once_with( - feedback="this looks good, proceed!", - outcomes=["needs_changes", "approved"], - llm="gemini/gemini-2.0-flash", - ) + mock_collapse.assert_called_once() + call_kwargs = mock_collapse.call_args + assert call_kwargs.kwargs["feedback"] == "this looks good, proceed!" + assert call_kwargs.kwargs["outcomes"] == ["needs_changes", "approved"] + # LLM should be a live object (from _hf_llm) or reconstructed, not None + assert call_kwargs.kwargs["llm"] is not None + assert getattr(call_kwargs.kwargs["llm"], "model", None) == "gemini-2.0-flash" assert flow2.last_human_feedback.outcome == "approved" assert flow2.result_path == "approved" @@ -1096,23 +1099,25 @@ class TestLLMObjectPreservedInContext: def test_provider_prefix_added_to_bare_model(self) -> None: """Test that provider prefix is added when model has no slash.""" from crewai.flow.human_feedback import _serialize_llm_for_context + from crewai.llm import LLM - mock_obj = MagicMock() - mock_obj.model = "gemini-3-flash-preview" - mock_obj.provider = "gemini" - assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-3-flash-preview" + llm = LLM(model="gemini-2.0-flash", provider="gemini") + result = _serialize_llm_for_context(llm) + assert isinstance(result, dict) + assert result["model"] == "gemini/gemini-2.0-flash" def test_provider_prefix_not_doubled_when_already_present(self) -> None: """Test that provider prefix is not added when model already has a slash.""" from crewai.flow.human_feedback import _serialize_llm_for_context + from crewai.llm import LLM - mock_obj = MagicMock() - mock_obj.model = "gemini/gemini-2.0-flash" - mock_obj.provider = "gemini" - assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-2.0-flash" + llm = LLM(model="gemini/gemini-2.0-flash") + result = _serialize_llm_for_context(llm) + assert isinstance(result, dict) + assert result["model"] == "gemini/gemini-2.0-flash" def test_no_provider_attr_falls_back_to_bare_model(self) -> None: - """Test that bare model is used when no provider attribute exists.""" + """Test that objects without to_config_dict fall back to model string.""" from crewai.flow.human_feedback import _serialize_llm_for_context mock_obj = MagicMock(spec=[]) @@ -1402,9 +1407,11 @@ class TestLiveLLMPreservationOnResume: with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm): flow.resume("looks good!") - # Should fall back to the serialized string + # Should fall back to deserialized LLM from context string assert len(captured_llm) == 1 - assert captured_llm[0] == "gpt-4o-mini" + from crewai.llms.base_llm import BaseLLM as BaseLLMClass + assert isinstance(captured_llm[0], BaseLLMClass) + assert captured_llm[0].model == "gpt-4o-mini" @patch("crewai.flow.flow.crewai_event_bus.emit") def test_resume_async_uses_string_from_context_when_hf_llm_is_string( @@ -1461,9 +1468,11 @@ class TestLiveLLMPreservationOnResume: with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm): flow.resume("looks good!") - # Should use context.llm since _hf_llm is a string (not BaseLLM) + # _hf_llm is a string, so resume deserializes context.llm into an LLM instance assert len(captured_llm) == 1 - assert captured_llm[0] == "gpt-4o-mini" + from crewai.llms.base_llm import BaseLLM as BaseLLMClass + assert isinstance(captured_llm[0], BaseLLMClass) + assert captured_llm[0].model == "gpt-4o-mini" def test_hf_llm_set_for_async_wrapper(self) -> None: """Test that _hf_llm is set on async wrapper functions."""