mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-04-30 14:52:36 +00:00)
better serialization for human feedback in flow with models defined as dicts (#5029)
* better serialization for human feedback in flow with models defined as dicts
* linted
* linted
* fix and adjust tests
@@ -60,7 +60,7 @@ class PendingFeedbackContext:
     emit: list[str] | None = None
     default_outcome: str | None = None
     metadata: dict[str, Any] = field(default_factory=dict)
-    llm: str | None = None
+    llm: dict[str, Any] | str | None = None
    requested_at: datetime = field(default_factory=datetime.now)

     def to_dict(self) -> dict[str, Any]:
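Note: the widened `llm` field now holds either of two shapes. A minimal illustration (the values are invented for the example, not taken from the diff):

    # New format: a config dict carrying the provider-prefixed model plus extras.
    context_llm_new = {"model": "gemini/gemini-2.0-flash", "temperature": 0.3}

    # Legacy format: a bare model string, still accepted for backward compatibility.
    context_llm_legacy = "gemini/gemini-2.0-flash"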
@@ -1316,25 +1316,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
         emit = context.emit
         default_outcome = context.default_outcome

-        # Try to get the live LLM from the re-imported decorator instead of the
-        # serialized string. When a flow pauses for HITL and resumes (possibly in
-        # a different process), context.llm only contains a model string like
-        # 'gemini/gemini-3-flash-preview'. This loses credentials, project,
-        # location, safety_settings, and client_params. By looking up the method
-        # on the re-imported flow class, we can retrieve the fully-configured LLM
-        # that was passed to the @human_feedback decorator.
-        llm = context.llm  # fallback to serialized string
+        # Try to get the live LLM from the re-imported decorator first.
+        # This preserves the fully-configured object (credentials, safety_settings, etc.)
+        # for same-process resume. For cross-process resume, fall back to the
+        # serialized context.llm which is now a dict with full config (or a legacy string).
+        from crewai.flow.human_feedback import _deserialize_llm_from_context
+
+        llm = None
         method = self._methods.get(FlowMethodName(context.method_name))
         if method is not None:
             live_llm = getattr(method, "_hf_llm", None)
             if live_llm is not None:
                 from crewai.llms.base_llm import BaseLLM as BaseLLMClass

                 # Only use live LLM if it's a BaseLLM instance (not a string)
                 # String values offer no benefit over the serialized context.llm
                 if isinstance(live_llm, BaseLLMClass):
                     llm = live_llm

+        if llm is None:
+            llm = _deserialize_llm_from_context(context.llm)
+
         # Determine outcome
         collapsed_outcome: str | None = None
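Note: condensed, the resume path now resolves the LLM in two steps: prefer the live object stashed on the re-imported method, otherwise rebuild one from the persisted context. A sketch of that order (`method` and `context` stand for the re-imported flow method and the loaded PendingFeedbackContext; this is an illustration, not the actual Flow code):

    from crewai.flow.human_feedback import _deserialize_llm_from_context
    from crewai.llms.base_llm import BaseLLM

    def resolve_resume_llm(method, context):
        # Same-process resume: @human_feedback stashed the configured LLM as _hf_llm.
        live_llm = getattr(method, "_hf_llm", None)
        if isinstance(live_llm, BaseLLM):
            return live_llm
        # Cross-process resume: rebuild from the serialized dict (or legacy string).
        return _deserialize_llm_from_context(context.llm)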
@@ -76,22 +76,48 @@ if TYPE_CHECKING:
 F = TypeVar("F", bound=Callable[..., Any])


-def _serialize_llm_for_context(llm: Any) -> str | None:
-    """Serialize a BaseLLM object to a model string with provider prefix.
+def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
+    """Serialize a BaseLLM object to a dict preserving full config.

-    When persisting the LLM for HITL resume, we need to store enough info
-    to reconstruct a working LLM on the resume worker. Just storing the bare
-    model name (e.g. "gemini-3-flash-preview") causes provider inference to
-    fail — it defaults to OpenAI. Including the provider prefix (e.g.
-    "gemini/gemini-3-flash-preview") allows LLM() to correctly route.
+    Delegates to ``llm.to_config_dict()`` when available (BaseLLM and
+    subclasses). Falls back to extracting the model string with provider
+    prefix for unknown LLM types.
     """
+    if hasattr(llm, "to_config_dict"):
+        return llm.to_config_dict()
+
+    # Fallback for non-BaseLLM objects: just extract model + provider prefix
     model = getattr(llm, "model", None)
     if not model:
         return None
     provider = getattr(llm, "provider", None)
-    if provider and "/" not in model:
-        return f"{provider}/{model}"
-    return model
+    return f"{provider}/{model}" if provider and "/" not in model else model
+
+
+def _deserialize_llm_from_context(
+    llm_data: dict[str, Any] | str | None,
+) -> BaseLLM | None:
+    """Reconstruct an LLM instance from serialized context data.
+
+    Handles both the new dict format (with full config) and the legacy
+    string format (model name only) for backward compatibility.
+
+    Returns a BaseLLM instance, or None if llm_data is None.
+    """
+    if llm_data is None:
+        return None
+
+    from crewai.llm import LLM
+
+    if isinstance(llm_data, str):
+        return LLM(model=llm_data)
+
+    if isinstance(llm_data, dict):
+        model = llm_data.pop("model", None)
+        if not model:
+            return None
+        return LLM(model=model, **llm_data)
+    return None


 @dataclass
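Note: together, the two helpers round-trip an LLM's configuration through the persisted context. A usage sketch (the expected dict contents are inferred from the tests further down):

    from crewai.flow.human_feedback import (
        _deserialize_llm_from_context,
        _serialize_llm_for_context,
    )
    from crewai.llm import LLM

    original = LLM(model="gpt-4o-mini", temperature=0.42)
    config = _serialize_llm_for_context(original)
    # e.g. {"model": "openai/gpt-4o-mini", "temperature": 0.42}
    restored = _deserialize_llm_from_context(config)
    assert restored is not None and restored.temperature == 0.42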
@@ -152,6 +152,28 @@ class BaseLLM(ABC):
             "cached_prompt_tokens": 0,
         }

+    def to_config_dict(self) -> dict[str, Any]:
+        """Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.
+
+        Returns the core fields that BaseLLM owns. Provider subclasses should
+        override this (calling ``super().to_config_dict()``) to add their own
+        fields (e.g. ``project``, ``location``, ``safety_settings``).
+        """
+        model = self.model
+        provider = self.provider
+        model_str = f"{provider}/{model}" if provider and "/" not in model else model
+
+        config: dict[str, Any] = {"model": model_str}
+
+        if self.temperature is not None:
+            config["temperature"] = self.temperature
+        if self.base_url is not None:
+            config["base_url"] = self.base_url
+        if self.stop:
+            config["stop"] = self.stop
+
+        return config
+
     @property
     def provider(self) -> str:
         """Get the provider of the LLM."""
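Note: the override contract described in the docstring, shown as a minimal sketch for a hypothetical provider subclass (`MyCompletion` and its `region` field are illustrative and not part of this diff; abstract members and __init__ are omitted):

    from typing import Any

    from crewai.llms.base_llm import BaseLLM

    class MyCompletion(BaseLLM):
        def to_config_dict(self) -> dict[str, Any]:
            # Start from the core fields BaseLLM owns (model, temperature, ...)
            config = super().to_config_dict()
            # ...then append provider-specific fields, skipping defaults.
            if self.region is not None:
                config["region"] = self.region
            return config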
@@ -256,6 +256,19 @@ class AnthropicCompletion(BaseLLM):
         else:
             self.stop_sequences = []

+    def to_config_dict(self) -> dict[str, Any]:
+        """Extend base config with Anthropic-specific fields."""
+        config = super().to_config_dict()
+        if self.max_tokens != 4096:  # non-default
+            config["max_tokens"] = self.max_tokens
+        if self.max_retries != 2:  # non-default
+            config["max_retries"] = self.max_retries
+        if self.top_p is not None:
+            config["top_p"] = self.top_p
+        if self.timeout is not None:
+            config["timeout"] = self.timeout
+        return config
+
     def _get_client_params(self) -> dict[str, Any]:
         """Get client parameters."""
@@ -180,6 +180,27 @@ class AzureCompletion(BaseLLM):
             and "/openai/deployments/" in self.endpoint
         )

+    def to_config_dict(self) -> dict[str, Any]:
+        """Extend base config with Azure-specific fields."""
+        config = super().to_config_dict()
+        if self.endpoint:
+            config["endpoint"] = self.endpoint
+        if self.api_version and self.api_version != "2024-06-01":
+            config["api_version"] = self.api_version
+        if self.timeout is not None:
+            config["timeout"] = self.timeout
+        if self.max_retries != 2:
+            config["max_retries"] = self.max_retries
+        if self.top_p is not None:
+            config["top_p"] = self.top_p
+        if self.frequency_penalty is not None:
+            config["frequency_penalty"] = self.frequency_penalty
+        if self.presence_penalty is not None:
+            config["presence_penalty"] = self.presence_penalty
+        if self.max_tokens is not None:
+            config["max_tokens"] = self.max_tokens
+        return config
+
     @staticmethod
     def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
         """Validate and fix Azure endpoint URL format.
@@ -346,6 +346,23 @@ class BedrockCompletion(BaseLLM):
             # Handle inference profiles for newer models
             self.model_id = model

+    def to_config_dict(self) -> dict[str, Any]:
+        """Extend base config with Bedrock-specific fields."""
+        config = super().to_config_dict()
+        # NOTE: AWS credentials (access_key, secret_key, session_token) are
+        # intentionally excluded — they must come from env on resume.
+        if self.region_name and self.region_name != "us-east-1":
+            config["region_name"] = self.region_name
+        if self.max_tokens is not None:
+            config["max_tokens"] = self.max_tokens
+        if self.top_p is not None:
+            config["top_p"] = self.top_p
+        if self.top_k is not None:
+            config["top_k"] = self.top_k
+        if self.guardrail_config:
+            config["guardrail_config"] = self.guardrail_config
+        return config
+
     @property
     def stop(self) -> list[str]:
         """Get stop sequences sent to the API."""
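Note: because credentials are deliberately never persisted, a resume worker must supply its own AWS environment. A hedged sketch of that precondition (the variable names are the standard AWS ones, nothing this diff adds):

    import os

    # The deserialized Bedrock config carries only model/region/sampling fields,
    # so credentials must come from the environment or an attached role.
    has_creds = "AWS_ACCESS_KEY_ID" in os.environ or "AWS_PROFILE" in os.environ
    assert has_creds, "resume worker needs AWS credentials from its environment"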
@@ -1880,7 +1897,9 @@ class BedrockCompletion(BaseLLM):
             # Anthropic (Claude) models reject assistant-last messages when
             # tools are in the request. Append a user message so the
             # Converse API accepts the payload.
-            elif "anthropic" in self.model.lower() or "claude" in self.model.lower():
+            elif (
+                "anthropic" in self.model.lower() or "claude" in self.model.lower()
+            ):
                 converse_messages.append(
                     {
                         "role": "user",
@@ -176,6 +176,28 @@ class GeminiCompletion(BaseLLM):
         else:
             self.stop_sequences = []

+    def to_config_dict(self) -> dict[str, Any]:
+        """Extend base config with Gemini/Vertex-specific fields."""
+        config = super().to_config_dict()
+        if self.project:
+            config["project"] = self.project
+        if self.location and self.location != "us-central1":
+            config["location"] = self.location
+        if self.top_p is not None:
+            config["top_p"] = self.top_p
+        if self.top_k is not None:
+            config["top_k"] = self.top_k
+        if self.max_output_tokens is not None:
+            config["max_output_tokens"] = self.max_output_tokens
+        if self.safety_settings:
+            config["safety_settings"] = [
+                {"category": str(s.category), "threshold": str(s.threshold)}
+                if hasattr(s, "category") and hasattr(s, "threshold")
+                else s
+                for s in self.safety_settings
+            ]
+        return config
+
     def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
         """Initialize the Google Gen AI client with proper parameter handling.
@@ -329,6 +329,35 @@ class OpenAICompletion(BaseLLM):
         """
         self._last_reasoning_items = None

+    def to_config_dict(self) -> dict[str, Any]:
+        """Extend base config with OpenAI-specific fields."""
+        config = super().to_config_dict()
+        # Client-level params (from OpenAI SDK)
+        if self.organization:
+            config["organization"] = self.organization
+        if self.project:
+            config["project"] = self.project
+        if self.timeout is not None:
+            config["timeout"] = self.timeout
+        if self.max_retries != 2:
+            config["max_retries"] = self.max_retries
+        # Completion params
+        if self.top_p is not None:
+            config["top_p"] = self.top_p
+        if self.frequency_penalty is not None:
+            config["frequency_penalty"] = self.frequency_penalty
+        if self.presence_penalty is not None:
+            config["presence_penalty"] = self.presence_penalty
+        if self.max_tokens is not None:
+            config["max_tokens"] = self.max_tokens
+        if self.max_completion_tokens is not None:
+            config["max_completion_tokens"] = self.max_completion_tokens
+        if self.seed is not None:
+            config["seed"] = self.seed
+        if self.reasoning_effort is not None:
+            config["reasoning_effort"] = self.reasoning_effort
+        return config
+
     def _get_client_params(self) -> dict[str, Any]:
         """Get OpenAI client parameters."""
@@ -988,11 +988,9 @@ class TestLLMObjectPreservedInContext:
         db_path = os.path.join(tmpdir, "test_flows.db")
         persistence = SQLiteFlowPersistence(db_path)

-        # Create a mock BaseLLM object (not a string)
-        # Simulates LLM(model="gemini-2.0-flash", provider="gemini")
-        mock_llm_obj = MagicMock()
-        mock_llm_obj.model = "gemini-2.0-flash"
-        mock_llm_obj.provider = "gemini"
+        # Create a real LLM object (not a string)
+        from crewai.llm import LLM
+        mock_llm_obj = LLM(model="gemini-2.0-flash", provider="gemini")

         class PausingProvider:
             def __init__(self, persistence: SQLiteFlowPersistence):
@@ -1041,32 +1039,37 @@ class TestLLMObjectPreservedInContext:
         result = flow1.kickoff()
         assert isinstance(result, HumanFeedbackPending)

-        # Verify the context stored the model STRING, not None
+        # Verify the context stored the model config dict, not None
         assert provider.captured_context is not None
-        assert provider.captured_context.llm == "gemini/gemini-2.0-flash"
+        assert isinstance(provider.captured_context.llm, dict)
+        assert provider.captured_context.llm["model"] == "gemini/gemini-2.0-flash"

         # Verify it survives persistence roundtrip
         flow_id = result.context.flow_id
         loaded = persistence.load_pending_feedback(flow_id)
         assert loaded is not None
         _, loaded_context = loaded
-        assert loaded_context.llm == "gemini/gemini-2.0-flash"
+        assert isinstance(loaded_context.llm, dict)
+        assert loaded_context.llm["model"] == "gemini/gemini-2.0-flash"

         # Phase 2: Resume with positive feedback - should use LLM to classify
         flow2 = TestFlow.from_pending(flow_id, persistence)
         assert flow2._pending_feedback_context is not None
-        assert flow2._pending_feedback_context.llm == "gemini/gemini-2.0-flash"
+        assert isinstance(flow2._pending_feedback_context.llm, dict)
+        assert flow2._pending_feedback_context.llm["model"] == "gemini/gemini-2.0-flash"

         # Mock _collapse_to_outcome to verify it gets called (not skipped)
         with patch.object(flow2, "_collapse_to_outcome", return_value="approved") as mock_collapse:
             flow2.resume("this looks good, proceed!")

         # The key assertion: _collapse_to_outcome was called (not skipped due to llm=None)
-        mock_collapse.assert_called_once_with(
-            feedback="this looks good, proceed!",
-            outcomes=["needs_changes", "approved"],
-            llm="gemini/gemini-2.0-flash",
-        )
+        mock_collapse.assert_called_once()
+        call_kwargs = mock_collapse.call_args
+        assert call_kwargs.kwargs["feedback"] == "this looks good, proceed!"
+        assert call_kwargs.kwargs["outcomes"] == ["needs_changes", "approved"]
+        # LLM should be a live object (from _hf_llm) or reconstructed, not None
+        assert call_kwargs.kwargs["llm"] is not None
+        assert getattr(call_kwargs.kwargs["llm"], "model", None) == "gemini-2.0-flash"
         assert flow2.last_human_feedback.outcome == "approved"
         assert flow2.result_path == "approved"
@@ -1096,23 +1099,25 @@ class TestLLMObjectPreservedInContext:
     def test_provider_prefix_added_to_bare_model(self) -> None:
         """Test that provider prefix is added when model has no slash."""
         from crewai.flow.human_feedback import _serialize_llm_for_context
+        from crewai.llm import LLM

-        mock_obj = MagicMock()
-        mock_obj.model = "gemini-3-flash-preview"
-        mock_obj.provider = "gemini"
-        assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-3-flash-preview"
+        llm = LLM(model="gemini-2.0-flash", provider="gemini")
+        result = _serialize_llm_for_context(llm)
+        assert isinstance(result, dict)
+        assert result["model"] == "gemini/gemini-2.0-flash"

     def test_provider_prefix_not_doubled_when_already_present(self) -> None:
         """Test that provider prefix is not added when model already has a slash."""
         from crewai.flow.human_feedback import _serialize_llm_for_context
+        from crewai.llm import LLM

-        mock_obj = MagicMock()
-        mock_obj.model = "gemini/gemini-2.0-flash"
-        mock_obj.provider = "gemini"
-        assert _serialize_llm_for_context(mock_obj) == "gemini/gemini-2.0-flash"
+        llm = LLM(model="gemini/gemini-2.0-flash")
+        result = _serialize_llm_for_context(llm)
+        assert isinstance(result, dict)
+        assert result["model"] == "gemini/gemini-2.0-flash"

     def test_no_provider_attr_falls_back_to_bare_model(self) -> None:
-        """Test that bare model is used when no provider attribute exists."""
+        """Test that objects without to_config_dict fall back to model string."""
         from crewai.flow.human_feedback import _serialize_llm_for_context

         mock_obj = MagicMock(spec=[])
@@ -1402,9 +1407,11 @@ class TestLiveLLMPreservationOnResume:
         with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
             flow.resume("looks good!")

-        # Should fall back to the serialized string
+        # Should fall back to deserialized LLM from context string
         assert len(captured_llm) == 1
-        assert captured_llm[0] == "gpt-4o-mini"
+        from crewai.llms.base_llm import BaseLLM as BaseLLMClass
+        assert isinstance(captured_llm[0], BaseLLMClass)
+        assert captured_llm[0].model == "gpt-4o-mini"

     @patch("crewai.flow.flow.crewai_event_bus.emit")
     def test_resume_async_uses_string_from_context_when_hf_llm_is_string(
@@ -1461,9 +1468,11 @@ class TestLiveLLMPreservationOnResume:
         with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm):
             flow.resume("looks good!")

-        # Should use context.llm since _hf_llm is a string (not BaseLLM)
+        # _hf_llm is a string, so resume deserializes context.llm into an LLM instance
         assert len(captured_llm) == 1
-        assert captured_llm[0] == "gpt-4o-mini"
+        from crewai.llms.base_llm import BaseLLM as BaseLLMClass
+        assert isinstance(captured_llm[0], BaseLLMClass)
+        assert captured_llm[0].model == "gpt-4o-mini"

     def test_hf_llm_set_for_async_wrapper(self) -> None:
         """Test that _hf_llm is set on async wrapper functions."""
@@ -772,3 +772,204 @@ class TestEdgeCases:
         assert result.output == "content"
         assert result.feedback == "feedback"
         assert result.outcome is None  # No routing, no outcome
+
+
+class TestLLMConfigPreservation:
+    """Tests that LLM config is preserved through @human_feedback serialization.
+
+    PR #4970 introduced _hf_llm stashing so the live LLM object survives
+    decorator wrapping for same-process resume. The serialization path
+    (_serialize_llm_for_context / _deserialize_llm_from_context) preserves
+    config for cross-process resume.
+    """
+
+    def test_hf_llm_stashed_on_wrapper_with_llm_instance(self):
+        """Test that passing an LLM instance stashes it on the wrapper as _hf_llm."""
+        from crewai.llm import LLM
+
+        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
+
+        class ConfigFlow(Flow):
+            @start()
+            @human_feedback(
+                message="Review:",
+                emit=["approved", "rejected"],
+                llm=llm_instance,
+            )
+            def review(self):
+                return "content"
+
+        method = ConfigFlow.review
+        assert hasattr(method, "_hf_llm"), "_hf_llm not found on wrapper"
+        assert method._hf_llm is llm_instance, "_hf_llm is not the same object"
+
+    def test_hf_llm_preserved_on_listen_method(self):
+        """Test that _hf_llm is preserved when @human_feedback is on a @listen method."""
+        from crewai.llm import LLM
+
+        llm_instance = LLM(model="gpt-4o-mini", temperature=0.7)
+
+        class ListenConfigFlow(Flow):
+            @start()
+            def generate(self):
+                return "draft"
+
+            @listen("generate")
+            @human_feedback(
+                message="Review:",
+                emit=["approved", "rejected"],
+                llm=llm_instance,
+            )
+            def review(self):
+                return "content"
+
+        method = ListenConfigFlow.review
+        assert hasattr(method, "_hf_llm")
+        assert method._hf_llm is llm_instance
+
+    def test_hf_llm_accessible_on_instance(self):
+        """Test that _hf_llm survives Flow instantiation (bound method access)."""
+        from crewai.llm import LLM
+
+        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
+
+        class InstanceFlow(Flow):
+            @start()
+            @human_feedback(
+                message="Review:",
+                emit=["approved", "rejected"],
+                llm=llm_instance,
+            )
+            def review(self):
+                return "content"
+
+        flow = InstanceFlow()
+        instance_method = flow.review
+        assert hasattr(instance_method, "_hf_llm")
+        assert instance_method._hf_llm is llm_instance
+
+    def test_serialize_llm_preserves_config_fields(self):
+        """Test that _serialize_llm_for_context captures temperature, base_url, etc."""
+        from crewai.flow.human_feedback import _serialize_llm_for_context
+        from crewai.llm import LLM
+
+        llm = LLM(
+            model="gpt-4o-mini",
+            temperature=0.42,
+            base_url="https://custom.example.com/v1",
+        )
+
+        serialized = _serialize_llm_for_context(llm)
+
+        assert isinstance(serialized, dict), f"Expected dict, got {type(serialized)}"
+        assert serialized["model"] == "openai/gpt-4o-mini"
+        assert serialized["temperature"] == 0.42
+        assert serialized["base_url"] == "https://custom.example.com/v1"
+
+    def test_serialize_llm_excludes_api_key(self):
+        """Test that api_key is NOT included in serialized output (security)."""
+        from crewai.flow.human_feedback import _serialize_llm_for_context
+        from crewai.llm import LLM
+
+        llm = LLM(model="gpt-4o-mini")
+
+        serialized = _serialize_llm_for_context(llm)
+        assert isinstance(serialized, dict)
+        assert "api_key" not in serialized
+
+    def test_deserialize_round_trip_preserves_config(self):
+        """Test that serialize → deserialize round-trip preserves all config."""
+        from crewai.flow.human_feedback import (
+            _deserialize_llm_from_context,
+            _serialize_llm_for_context,
+        )
+        from crewai.llm import LLM
+
+        original = LLM(
+            model="gpt-4o-mini",
+            temperature=0.42,
+            base_url="https://custom.example.com/v1",
+        )
+
+        serialized = _serialize_llm_for_context(original)
+        reconstructed = _deserialize_llm_from_context(serialized)
+
+        assert reconstructed is not None
+        assert reconstructed.model == original.model
+        assert reconstructed.temperature == original.temperature
+        assert reconstructed.base_url == original.base_url
+
+    def test_deserialize_handles_legacy_string_format(self):
+        """Test backward compat: plain string still reconstructs an LLM."""
+        from crewai.flow.human_feedback import _deserialize_llm_from_context
+
+        reconstructed = _deserialize_llm_from_context("openai/gpt-4o-mini")
+
+        assert reconstructed is not None
+        assert reconstructed.model == "gpt-4o-mini"
+
+    def test_deserialize_returns_none_for_none(self):
+        """Test that None input returns None."""
+        from crewai.flow.human_feedback import _deserialize_llm_from_context
+
+        assert _deserialize_llm_from_context(None) is None
+
+    def test_serialize_llm_preserves_provider_specific_fields(self):
+        """Test that provider-specific fields like project/location are serialized."""
+        from crewai.flow.human_feedback import _serialize_llm_for_context
+        from crewai.llm import LLM
+
+        # Create a Gemini-style LLM with project and non-default location
+        llm = LLM(
+            model="gemini-2.0-flash",
+            provider="gemini",
+            project="my-project",
+            location="europe-west1",
+            temperature=0.3,
+        )
+
+        serialized = _serialize_llm_for_context(llm)
+
+        assert isinstance(serialized, dict)
+        assert serialized.get("project") == "my-project"
+        assert serialized.get("location") == "europe-west1"
+        assert serialized.get("temperature") == 0.3
+
+    def test_config_preserved_through_full_flow_execution(self):
+        """Test that the LLM with custom config is used during outcome collapsing."""
+        from crewai.llm import LLM
+
+        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
+        collapse_calls = []
+
+        class FullFlow(Flow):
+            @start()
+            @human_feedback(
+                message="Review:",
+                emit=["approved", "rejected"],
+                llm=llm_instance,
+            )
+            def review(self):
+                return "content"
+
+            @listen("approved")
+            def on_approved(self):
+                return "done"
+
+        flow = FullFlow()
+
+        original_collapse = flow._collapse_to_outcome
+
+        def spy_collapse(feedback, outcomes, llm):
+            collapse_calls.append(llm)
+            return "approved"
+
+        with (
+            patch.object(flow, "_request_human_feedback", return_value="looks good"),
+            patch.object(flow, "_collapse_to_outcome", side_effect=spy_collapse),
+        ):
+            flow.kickoff()
+
+        assert len(collapse_calls) == 1
+        # The LLM passed to _collapse_to_outcome should be the original instance
+        assert collapse_calls[0] is llm_instance