mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-27 05:12:52 +00:00
fix and adjust tests
This commit is contained in:
@@ -988,11 +988,9 @@ class TestLLMObjectPreservedInContext:
|
||||
db_path = os.path.join(tmpdir, "test_flows.db")
|
||||
persistence = SQLiteFlowPersistence(db_path)
|
||||
|
||||
# Create a mock BaseLLM object (not a string)
|
||||
# Simulates LLM(model="gemini-2.0-flash", provider="gemini")
|
||||
mock_llm_obj = MagicMock()
|
||||
mock_llm_obj.model = "gemini-2.0-flash"
|
||||
mock_llm_obj.provider = "gemini"
|
||||
# Create a real LLM object (not a string)
|
||||
from crewai.llm import LLM
|
||||
mock_llm_obj = LLM(model="gemini-2.0-flash", provider="gemini")
|
||||
|
||||
class PausingProvider:
|
||||
def __init__(self, persistence: SQLiteFlowPersistence):
|
||||
@@ -1041,32 +1039,37 @@ class TestLLMObjectPreservedInContext:
|
||||
result = flow1.kickoff()
|
||||
assert isinstance(result, HumanFeedbackPending)
|
||||
|
||||
# Verify the context stored the model STRING, not None
|
||||
# Verify the context stored the model config dict, not None
|
||||
assert provider.captured_context is not None
|
||||
assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
|
||||
assert isinstance(provider.captured_context.llm, dict)
|
||||
assert provider.captured_context.llm["model"] == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Verify it survives persistence roundtrip
|
||||
flow_id = result.context.flow_id
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
assert loaded is not None
|
||||
_, loaded_context = loaded
|
||||
assert loaded_context.llm == "gemini/gemini-2.0-flash"
|
||||
assert isinstance(loaded_context.llm, dict)
|
||||
assert loaded_context.llm["model"] == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Phase 2: Resume with positive feedback - should use LLM to classify
|
||||
flow2 = TestFlow.from_pending(flow_id, persistence)
|
||||
assert flow2._pending_feedback_context is not None
|
||||
assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
|
||||
assert isinstance(flow2._pending_feedback_context.llm, dict)
|
||||
assert flow2._pending_feedback_context.llm["model"] == "gemini/gemini-2.0-flash"
|
||||
|
||||
# Mock _collapse_to_outcome to verify it gets called (not skipped)
|
||||
with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
|
||||
flow2.resume("this looks good, proceed!")
|
||||
|
||||
# The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
|
||||
mock_collapse.assert_called_once_with(
|
||||
feedback="this looks good, proceed!",
|
||||
outcomes=["needs_changes", "approved"],
|
||||
llm="gemini/gemini-2.0-flash",
|
||||
)
|
||||
mock_collapse.assert_called_once()
|
||||
call_kwargs = mock_collapse.call_args
|
||||
assert call_kwargs.kwargs["feedback"] == "this looks good, proceed!"
|
||||
assert call_kwargs.kwargs["outcomes"] == ["needs_changes", "approved"]
|
||||
# LLM should be a live object (from _hf_llm) or reconstructed, not None
|
||||
assert call_kwargs.kwargs["llm"] is not None
|
||||
assert getattr(call_kwargs.kwargs["llm"], "model", None) == "gemini-2.0-flash"
|
||||
assert flow2.last_human_feedback.outcome == "approved"
|
||||
assert flow2.result_path == "approved"
|
||||
|
||||
@@ -1096,23 +1099,25 @@ class TestLLMObjectPreservedInContext:
|
||||
def test_provider_prefix_added_to_bare_model(self) -> None:
|
||||
"""Test that provider prefix is added when model has no slash."""
|
||||
from crewai.flow.human_feedback import _serialize_llm_for_context
|
||||
from crewai.llm import LLM
|
||||
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.model = "gemini-3-flash-preview"
|
||||
mock_obj.provider = "gemini"
|
||||
assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-3-flash-preview"
|
||||
llm = LLM(model="gemini-2.0-flash", provider="gemini")
|
||||
result = _serialize_llm_for_context(llm)
|
||||
assert isinstance(result, dict)
|
||||
assert result["model"] == "gemini/gemini-2.0-flash"
|
||||
|
||||
def test_provider_prefix_not_doubled_when_already_present(self) -> None:
|
||||
"""Test that provider prefix is not added when model already has a slash."""
|
||||
from crewai.flow.human_feedback import _serialize_llm_for_context
|
||||
from crewai.llm import LLM
|
||||
|
||||
mock_obj = MagicMock()
|
||||
mock_obj.model = "gemini/gemini-2.0-flash"
|
||||
mock_obj.provider = "gemini"
|
||||
assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-2.0-flash"
|
||||
llm = LLM(model="gemini/gemini-2.0-flash")
|
||||
result = _serialize_llm_for_context(llm)
|
||||
assert isinstance(result, dict)
|
||||
assert result["model"] == "gemini/gemini-2.0-flash"
|
||||
|
||||
def test_no_provider_attr_falls_back_to_bare_model(self) -> None:
|
||||
"""Test that bare model is used when no provider attribute exists."""
|
||||
"""Test that objects without to_config_dict fall back to model string."""
|
||||
from crewai.flow.human_feedback import _serialize_llm_for_context
|
||||
|
||||
mock_obj = MagicMock(spec=[])
|
||||
@@ -1402,9 +1407,11 @@ class TestLiveLLMPreservationOnResume:
|
||||
with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
|
||||
flow.resume("looks good!")
|
||||
|
||||
# Should fall back to the serialized string
|
||||
# Should fall back to deserialized LLM from context string
|
||||
assert len(captured_llm) == 1
|
||||
assert captured_llm[0] == "gpt-4o-mini"
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
assert isinstance(captured_llm[0], BaseLLMClass)
|
||||
assert captured_llm[0].model == "gpt-4o-mini"
|
||||
|
||||
@patch("crewai.flow.flow.crewai_event_bus.emit")
|
||||
def test_resume_async_uses_string_from_context_when_hf_llm_is_string(
|
||||
@@ -1461,9 +1468,11 @@ class TestLiveLLMPreservationOnResume:
|
||||
with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
|
||||
flow.resume("looks good!")
|
||||
|
||||
# Should use context.llm since _hf_llm is a string (not BaseLLM)
|
||||
# _hf_llm is a string, so resume deserializes context.llm into an LLM instance
|
||||
assert len(captured_llm) == 1
|
||||
assert captured_llm[0] == "gpt-4o-mini"
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
assert isinstance(captured_llm[0], BaseLLMClass)
|
||||
assert captured_llm[0].model == "gpt-4o-mini"
|
||||
|
||||
def test_hf_llm_set_for_async_wrapper(self) -> None:
|
||||
"""Test that _hf_llm is set on async wrapper functions."""
|
||||
|
||||
Reference in New Issue
Block a user