Better serialization for human feedback in flows with models defined as dicts

This commit is contained in:
lorenzejay
2026-03-23 11:37:56 -07:00
parent c92de53da7
commit ee5b9bb479
10 changed files with 374 additions and 21 deletions

View File

@@ -60,7 +60,7 @@ class PendingFeedbackContext:
emit: list[str] | None = None
default_outcome: str | None = None
metadata: dict[str, Any] = field(default_factory=dict)
llm: str | None = None
llm: dict[str, Any] | str | None = None
requested_at: datetime = field(default_factory=datetime.now)
def to_dict(self) -> dict[str, Any]:

View File

@@ -1316,25 +1316,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
emit = context.emit
default_outcome = context.default_outcome
# Try to get the live LLM from the re-imported decorator instead of the
# serialized string. When a flow pauses for HITL and resumes (possibly in
# a different process), context.llm only contains a model string like
# 'gemini/gemini-3-flash-preview'. This loses credentials, project,
# location, safety_settings, and client_params. By looking up the method
# on the re-imported flow class, we can retrieve the fully-configured LLM
# that was passed to the @human_feedback decorator.
llm = context.llm # fallback to serialized string
# Try to get the live LLM from the re-imported decorator first.
# This preserves the fully-configured object (credentials, safety_settings, etc.)
# for same-process resume. For cross-process resume, fall back to the
# serialized context.llm which is now a dict with full config (or a legacy string).
from crewai.flow.human_feedback import _deserialize_llm_from_context
llm = None
method = self._methods.get(FlowMethodName(context.method_name))
if method is not None:
live_llm = getattr(method, "_hf_llm", None)
if live_llm is not None:
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
# Only use live LLM if it's a BaseLLM instance (not a string)
# String values offer no benefit over the serialized context.llm
if isinstance(live_llm, BaseLLMClass):
llm = live_llm
if llm is None:
llm = _deserialize_llm_from_context(context.llm)
# Determine outcome
collapsed_outcome: str | None = None

View File

@@ -76,22 +76,47 @@ if TYPE_CHECKING:
F = TypeVar("F", bound=Callable[..., Any])
def _serialize_llm_for_context(llm: Any) -> str | None:
"""Serialize a BaseLLM object to a model string with provider prefix.
def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
"""Serialize a BaseLLM object to a dict preserving full config.
When persisting the LLM for HITL resume, we need to store enough info
to reconstruct a working LLM on the resume worker. Just storing the bare
model name (e.g. "gemini-3-flash-preview") causes provider inference to
fail — it defaults to OpenAI. Including the provider prefix (e.g.
"gemini/gemini-3-flash-preview") allows LLM() to correctly route.
Delegates to ``llm.to_config_dict()`` when available (BaseLLM and
subclasses). Falls back to extracting the model string with provider
prefix for unknown LLM types.
"""
if hasattr(llm, "to_config_dict"):
return llm.to_config_dict()
# Fallback for non-BaseLLM objects: just extract model + provider prefix
model = getattr(llm, "model", None)
if not model:
return None
provider = getattr(llm, "provider", None)
if provider and "/" not in model:
return f"{provider}/{model}"
return model
return f"{provider}/{model}" if provider and "/" not in model else model
def _deserialize_llm_from_context(llm_data: dict[str, Any] | str | None) -> BaseLLM | None:
"""Reconstruct an LLM instance from serialized context data.
Handles both the new dict format (with full config) and the legacy
string format (model name only) for backward compatibility.
Returns a BaseLLM instance, or None if llm_data is None.
"""
if llm_data is None:
return None
from crewai.llm import LLM
if isinstance(llm_data, str):
return LLM(model=llm_data)
if isinstance(llm_data, dict):
model = llm_data.pop("model", None)
if not model:
return None
return LLM(model=model, **llm_data)
return None
@dataclass

View File

@@ -152,6 +152,28 @@ class BaseLLM(ABC):
"cached_prompt_tokens": 0,
}
def to_config_dict(self) -> dict[str, Any]:
    """Serialize this LLM to a dict that can reconstruct it via ``LLM(**config)``.

    Returns the core fields that BaseLLM owns. Provider subclasses should
    override this (calling ``super().to_config_dict()``) to add their own
    fields (e.g. ``project``, ``location``, ``safety_settings``).
    """
    model = self.model
    provider = self.provider
    # Qualify the model with its provider so LLM() routes correctly on
    # reconstruction; leave already-prefixed model strings untouched.
    if provider and "/" not in model:
        qualified = f"{provider}/{model}"
    else:
        qualified = model
    result: dict[str, Any] = {"model": qualified}
    if self.temperature is not None:
        result["temperature"] = self.temperature
    if self.base_url is not None:
        result["base_url"] = self.base_url
    if self.stop:
        result["stop"] = self.stop
    return result
@property
def provider(self) -> str:
"""Get the provider of the LLM."""

View File

@@ -256,6 +256,19 @@ class AnthropicCompletion(BaseLLM):
else:
self.stop_sequences = []
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with Anthropic-specific fields.

    Only non-default values are emitted so the serialized dict stays
    minimal; constructor defaults are re-applied on reconstruction.
    """
    config = super().to_config_dict()
    # 4096 / 2 are the constructor defaults — skip them when unchanged.
    for key, value, default in (
        ("max_tokens", self.max_tokens, 4096),
        ("max_retries", self.max_retries, 2),
    ):
        if value != default:
            config[key] = value
    for key, value in (("top_p", self.top_p), ("timeout", self.timeout)):
        if value is not None:
            config[key] = value
    return config
def _get_client_params(self) -> dict[str, Any]:
"""Get client parameters."""

View File

@@ -180,6 +180,27 @@ class AzureCompletion(BaseLLM):
and "/openai/deployments/" in self.endpoint
)
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with Azure-specific fields.

    Persists the deployment endpoint plus any non-default client and
    completion parameters needed to rebuild this LLM on resume.
    """
    config = super().to_config_dict()
    if self.endpoint:
        config["endpoint"] = self.endpoint
    # "2024-06-01" is the default API version — only persist overrides.
    if self.api_version and self.api_version != "2024-06-01":
        config["api_version"] = self.api_version
    if self.max_retries != 2:  # constructor default
        config["max_retries"] = self.max_retries
    # Optional parameters: include only when explicitly set.
    optional = {
        "timeout": self.timeout,
        "top_p": self.top_p,
        "frequency_penalty": self.frequency_penalty,
        "presence_penalty": self.presence_penalty,
        "max_tokens": self.max_tokens,
    }
    config.update({key: value for key, value in optional.items() if value is not None})
    return config
@staticmethod
def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.

View File

@@ -346,6 +346,23 @@ class BedrockCompletion(BaseLLM):
# Handle inference profiles for newer models
self.model_id = model
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with Bedrock-specific fields.

    AWS credentials (access_key, secret_key, session_token) are
    intentionally excluded — they must come from env on resume.
    """
    config = super().to_config_dict()
    # "us-east-1" is the default region — only persist overrides.
    if self.region_name and self.region_name != "us-east-1":
        config["region_name"] = self.region_name
    # Sampling parameters: include only when explicitly set.
    for key in ("max_tokens", "top_p", "top_k"):
        value = getattr(self, key)
        if value is not None:
            config[key] = value
    if self.guardrail_config:
        config["guardrail_config"] = self.guardrail_config
    return config
@property
def stop(self) -> list[str]:
"""Get stop sequences sent to the API."""

View File

@@ -176,6 +176,31 @@ class GeminiCompletion(BaseLLM):
else:
self.stop_sequences = []
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with Gemini/Vertex-specific fields.

    Adds project/location (for Vertex routing), sampling parameters, and
    a best-effort serialization of safety_settings on top of the fields
    emitted by ``BaseLLM.to_config_dict``.
    """
    config = super().to_config_dict()
    if self.project:
        config["project"] = self.project
    # "us-central1" is treated as the default location — only persist overrides.
    if self.location and self.location != "us-central1":
        config["location"] = self.location
    if self.top_p is not None:
        config["top_p"] = self.top_p
    if self.top_k is not None:
        config["top_k"] = self.top_k
    if self.max_output_tokens is not None:
        config["max_output_tokens"] = self.max_output_tokens
    if self.safety_settings:
        # Best effort: SDK safety-setting objects are flattened to
        # {"category": ..., "threshold": ...} string dicts; entries that
        # are already plain data pass through unchanged.
        # NOTE(review): str(s.category) likely yields enum reprs like
        # "HarmCategory.HARM_..." — confirm LLM(**config) accepts these
        # on reconstruction, otherwise the round-trip drops safety config.
        try:
            config["safety_settings"] = [
                {"category": str(s.category), "threshold": str(s.threshold)}
                if hasattr(s, "category") and hasattr(s, "threshold")
                else s
                for s in self.safety_settings
            ]
        except Exception:
            # Deliberate best-effort: an unserializable safety setting must
            # not break HITL persistence; the field is simply omitted.
            pass
    return config
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
"""Initialize the Google Gen AI client with proper parameter handling.

View File

@@ -329,6 +329,35 @@ class OpenAICompletion(BaseLLM):
"""
self._last_reasoning_items = None
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with OpenAI-specific fields.

    Captures both client-level parameters (organization, project,
    timeout, retries) and completion parameters, emitting only values
    that differ from None / constructor defaults.
    """
    config = super().to_config_dict()
    # Client-level params (from OpenAI SDK).
    if self.organization:
        config["organization"] = self.organization
    if self.project:
        config["project"] = self.project
    if self.timeout is not None:
        config["timeout"] = self.timeout
    if self.max_retries != 2:  # constructor default
        config["max_retries"] = self.max_retries
    # Completion params: include only when explicitly set.
    for key in (
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "max_tokens",
        "max_completion_tokens",
        "seed",
        "reasoning_effort",
    ):
        value = getattr(self, key)
        if value is not None:
            config[key] = value
    return config
def _get_client_params(self) -> dict[str, Any]:
"""Get OpenAI client parameters."""

View File

@@ -772,3 +772,204 @@ class TestEdgeCases:
assert result.output == "content"
assert result.feedback == "feedback"
assert result.outcome is None # No routing, no outcome
class TestLLMConfigPreservation:
    """Tests that LLM config is preserved through @human_feedback serialization.

    PR #4970 introduced _hf_llm stashing so the live LLM object survives
    decorator wrapping for same-process resume. The serialization path
    (_serialize_llm_for_context / _deserialize_llm_from_context) preserves
    config for cross-process resume.
    """

    def test_hf_llm_stashed_on_wrapper_with_llm_instance(self):
        """Test that passing an LLM instance stashes it on the wrapper as _hf_llm."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)

        class ConfigFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        method = ConfigFlow.review
        assert hasattr(method, "_hf_llm"), "_hf_llm not found on wrapper"
        assert method._hf_llm is llm_instance, "_hf_llm is not the same object"

    def test_hf_llm_preserved_on_listen_method(self):
        """Test that _hf_llm is preserved when @human_feedback is on a @listen method."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.7)

        class ListenConfigFlow(Flow):
            @start()
            def generate(self):
                return "draft"

            @listen("generate")
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        method = ListenConfigFlow.review
        assert hasattr(method, "_hf_llm")
        assert method._hf_llm is llm_instance

    def test_hf_llm_accessible_on_instance(self):
        """Test that _hf_llm survives Flow instantiation (bound method access)."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)

        class InstanceFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        flow = InstanceFlow()
        instance_method = flow.review
        assert hasattr(instance_method, "_hf_llm")
        assert instance_method._hf_llm is llm_instance

    def test_serialize_llm_preserves_config_fields(self):
        """Test that _serialize_llm_for_context captures temperature, base_url, etc."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        llm = LLM(
            model="gpt-4o-mini",
            temperature=0.42,
            base_url="https://custom.example.com/v1",
        )
        serialized = _serialize_llm_for_context(llm)
        assert isinstance(serialized, dict), f"Expected dict, got {type(serialized)}"
        assert serialized["model"] == "openai/gpt-4o-mini"
        assert serialized["temperature"] == 0.42
        assert serialized["base_url"] == "https://custom.example.com/v1"

    def test_serialize_llm_excludes_api_key(self):
        """Test that api_key is NOT included in serialized output (security)."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        llm = LLM(model="gpt-4o-mini")
        serialized = _serialize_llm_for_context(llm)
        assert isinstance(serialized, dict)
        assert "api_key" not in serialized

    def test_deserialize_round_trip_preserves_config(self):
        """Test that serialize → deserialize round-trip preserves all config."""
        from crewai.flow.human_feedback import (
            _deserialize_llm_from_context,
            _serialize_llm_for_context,
        )
        from crewai.llm import LLM

        original = LLM(
            model="gpt-4o-mini",
            temperature=0.42,
            base_url="https://custom.example.com/v1",
        )
        serialized = _serialize_llm_for_context(original)
        reconstructed = _deserialize_llm_from_context(serialized)
        assert reconstructed is not None
        assert reconstructed.model == original.model
        assert reconstructed.temperature == original.temperature
        assert reconstructed.base_url == original.base_url

    def test_deserialize_handles_legacy_string_format(self):
        """Test backward compat: plain string still reconstructs an LLM."""
        from crewai.flow.human_feedback import _deserialize_llm_from_context

        reconstructed = _deserialize_llm_from_context("openai/gpt-4o-mini")
        assert reconstructed is not None
        assert reconstructed.model == "gpt-4o-mini"

    def test_deserialize_returns_none_for_none(self):
        """Test that None input returns None."""
        from crewai.flow.human_feedback import _deserialize_llm_from_context

        assert _deserialize_llm_from_context(None) is None

    def test_serialize_llm_preserves_provider_specific_fields(self):
        """Test that provider-specific fields like project/location are serialized."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        # Create a Gemini-style LLM with project and non-default location
        llm = LLM(
            model="gemini-2.0-flash",
            provider="gemini",
            project="my-project",
            location="europe-west1",
            temperature=0.3,
        )
        serialized = _serialize_llm_for_context(llm)
        assert isinstance(serialized, dict)
        assert serialized.get("project") == "my-project"
        assert serialized.get("location") == "europe-west1"
        assert serialized.get("temperature") == 0.3

    def test_config_preserved_through_full_flow_execution(self):
        """Test that the LLM with custom config is used during outcome collapsing."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
        collapse_calls = []

        class FullFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

            @listen("approved")
            def on_approved(self):
                return "done"

        flow = FullFlow()

        # Removed dead code: the original bound flow._collapse_to_outcome to
        # an unused local before patching it.
        def spy_collapse(feedback, outcomes, llm):
            collapse_calls.append(llm)
            return "approved"

        with (
            patch.object(flow, "_request_human_feedback", return_value="looks good"),
            patch.object(flow, "_collapse_to_outcome", side_effect=spy_collapse),
        ):
            flow.kickoff()

        assert len(collapse_calls) == 1
        # The LLM passed to _collapse_to_outcome should be the original instance
        assert collapse_calls[0] is llm_instance