mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-30 23:02:50 +00:00
better serialization for human feedback in flow with models defined as dicts
This commit is contained in:
@@ -60,7 +60,7 @@ class PendingFeedbackContext:
|
||||
emit: list[str] | None = None
|
||||
default_outcome: str | None = None
|
||||
metadata: dict[str, Any] = field(default_factory=dict)
|
||||
llm: str | None = None
|
||||
llm: dict[str, Any] | str | None = None
|
||||
requested_at: datetime = field(default_factory=datetime.now)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
|
||||
@@ -1316,25 +1316,25 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
emit = context.emit
|
||||
default_outcome = context.default_outcome
|
||||
|
||||
# Try to get the live LLM from the re-imported decorator instead of the
|
||||
# serialized string. When a flow pauses for HITL and resumes (possibly in
|
||||
# a different process), context.llm only contains a model string like
|
||||
# 'gemini/gemini-3-flash-preview'. This loses credentials, project,
|
||||
# location, safety_settings, and client_params. By looking up the method
|
||||
# on the re-imported flow class, we can retrieve the fully-configured LLM
|
||||
# that was passed to the @human_feedback decorator.
|
||||
llm = context.llm # fallback to serialized string
|
||||
# Try to get the live LLM from the re-imported decorator first.
|
||||
# This preserves the fully-configured object (credentials, safety_settings, etc.)
|
||||
# for same-process resume. For cross-process resume, fall back to the
|
||||
# serialized context.llm which is now a dict with full config (or a legacy string).
|
||||
from crewai.flow.human_feedback import _deserialize_llm_from_context
|
||||
|
||||
llm = None
|
||||
method = self._methods.get(FlowMethodName(context.method_name))
|
||||
if method is not None:
|
||||
live_llm = getattr(method, "_hf_llm", None)
|
||||
if live_llm is not None:
|
||||
from crewai.llms.base_llm import BaseLLM as BaseLLMClass
|
||||
|
||||
# Only use live LLM if it's a BaseLLM instance (not a string)
|
||||
# String values offer no benefit over the serialized context.llm
|
||||
if isinstance(live_llm, BaseLLMClass):
|
||||
llm = live_llm
|
||||
|
||||
if llm is None:
|
||||
llm = _deserialize_llm_from_context(context.llm)
|
||||
|
||||
# Determine outcome
|
||||
collapsed_outcome: str | None = None
|
||||
|
||||
|
||||
@@ -76,22 +76,47 @@ if TYPE_CHECKING:
|
||||
F = TypeVar("F", bound=Callable[..., Any])
|
||||
|
||||
|
||||
def _serialize_llm_for_context(llm: Any) -> str | None:
|
||||
"""Serialize a BaseLLM object to a model string with provider prefix.
|
||||
def _serialize_llm_for_context(llm: Any) -> dict[str, Any] | str | None:
|
||||
"""Serialize a BaseLLM object to a dict preserving full config.
|
||||
|
||||
When persisting the LLM for HITL resume, we need to store enough info
|
||||
to reconstruct a working LLM on the resume worker. Just storing the bare
|
||||
model name (e.g. "gemini-3-flash-preview") causes provider inference to
|
||||
fail — it defaults to OpenAI. Including the provider prefix (e.g.
|
||||
"gemini/gemini-3-flash-preview") allows LLM() to correctly route.
|
||||
Delegates to ``llm.to_config_dict()`` when available (BaseLLM and
|
||||
subclasses). Falls back to extracting the model string with provider
|
||||
prefix for unknown LLM types.
|
||||
"""
|
||||
if hasattr(llm, "to_config_dict"):
|
||||
return llm.to_config_dict()
|
||||
|
||||
# Fallback for non-BaseLLM objects: just extract model + provider prefix
|
||||
model = getattr(llm, "model", None)
|
||||
if not model:
|
||||
return None
|
||||
provider = getattr(llm, "provider", None)
|
||||
if provider and "/" not in model:
|
||||
return f"{provider}/{model}"
|
||||
return model
|
||||
return f"{provider}/{model}" if provider and "/" not in model else model
|
||||
|
||||
|
||||
def _deserialize_llm_from_context(llm_data: dict[str, Any] | str | None) -> BaseLLM | None:
|
||||
"""Reconstruct an LLM instance from serialized context data.
|
||||
|
||||
Handles both the new dict format (with full config) and the legacy
|
||||
string format (model name only) for backward compatibility.
|
||||
|
||||
Returns a BaseLLM instance, or None if llm_data is None.
|
||||
"""
|
||||
if llm_data is None:
|
||||
return None
|
||||
|
||||
from crewai.llm import LLM
|
||||
|
||||
if isinstance(llm_data, str):
|
||||
return LLM(model=llm_data)
|
||||
|
||||
if isinstance(llm_data, dict):
|
||||
model = llm_data.pop("model", None)
|
||||
if not model:
|
||||
return None
|
||||
return LLM(model=model, **llm_data)
|
||||
return None
|
||||
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -152,6 +152,28 @@ class BaseLLM(ABC):
|
||||
"cached_prompt_tokens": 0,
|
||||
}
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Build a config dict from which ``LLM(**config)`` can rebuild this LLM.

    Only the core fields owned by BaseLLM are emitted here. Provider
    subclasses should override this method, call
    ``super().to_config_dict()``, and append their own fields
    (e.g. ``project``, ``location``, ``safety_settings``).

    Returns:
        A dict that always carries ``model`` and carries ``temperature``,
        ``base_url`` and ``stop`` only when they are set.
    """
    qualified_model = self.model
    provider_name = self.provider
    # Prefix the provider unless the model string already contains a "/".
    if provider_name and "/" not in qualified_model:
        qualified_model = f"{provider_name}/{qualified_model}"

    snapshot: dict[str, Any] = {"model": qualified_model}

    # Optional scalar fields are written only when explicitly set.
    for attr in ("temperature", "base_url"):
        value = getattr(self, attr)
        if value is not None:
            snapshot[attr] = value

    if self.stop:
        snapshot["stop"] = self.stop

    return snapshot
|
||||
|
||||
@property
|
||||
def provider(self) -> str:
|
||||
"""Get the provider of the LLM."""
|
||||
|
||||
@@ -256,6 +256,19 @@ class AnthropicCompletion(BaseLLM):
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Return the base config extended with Anthropic-specific fields.

    ``max_tokens`` and ``max_retries`` are written only when they differ
    from 4096 and 2 respectively (treated as the non-default case);
    ``top_p`` and ``timeout`` are written only when they are not None.
    """
    settings = super().to_config_dict()

    # Values equal to the defaults are left out of the serialized form.
    non_default = (("max_tokens", 4096), ("max_retries", 2))
    for attr, default in non_default:
        current = getattr(self, attr)
        if current != default:
            settings[attr] = current

    for attr in ("top_p", "timeout"):
        current = getattr(self, attr)
        if current is not None:
            settings[attr] = current

    return settings
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get client parameters."""
|
||||
|
||||
|
||||
@@ -180,6 +180,27 @@ class AzureCompletion(BaseLLM):
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Return the base config extended with Azure-specific fields.

    ``endpoint`` is written when truthy, ``api_version`` when truthy and
    different from "2024-06-01", ``max_retries`` when different from 2, and
    the remaining fields only when they are not None.
    """
    settings = super().to_config_dict()

    endpoint = self.endpoint
    if endpoint:
        settings["endpoint"] = endpoint

    api_version = self.api_version
    if api_version and api_version != "2024-06-01":
        settings["api_version"] = api_version

    if self.timeout is not None:
        settings["timeout"] = self.timeout

    if self.max_retries != 2:
        settings["max_retries"] = self.max_retries

    # Sampling and length parameters: present only when explicitly set.
    optional_fields = ("top_p", "frequency_penalty", "presence_penalty", "max_tokens")
    for attr in optional_fields:
        current = getattr(self, attr)
        if current is not None:
            settings[attr] = current

    return settings
|
||||
|
||||
@staticmethod
|
||||
def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
|
||||
"""Validate and fix Azure endpoint URL format.
|
||||
|
||||
@@ -346,6 +346,23 @@ class BedrockCompletion(BaseLLM):
|
||||
# Handle inference profiles for newer models
|
||||
self.model_id = model
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Return the base config extended with Bedrock-specific fields.

    ``region_name`` is written when truthy and different from "us-east-1";
    ``max_tokens``, ``top_p`` and ``top_k`` when they are not None; and
    ``guardrail_config`` when truthy.
    """
    settings = super().to_config_dict()

    # NOTE: AWS credentials (access_key, secret_key, session_token) are
    # intentionally excluded — they must come from env on resume.
    region = self.region_name
    if region and region != "us-east-1":
        settings["region_name"] = region

    for attr in ("max_tokens", "top_p", "top_k"):
        current = getattr(self, attr)
        if current is not None:
            settings[attr] = current

    guardrails = self.guardrail_config
    if guardrails:
        settings["guardrail_config"] = guardrails

    return settings
|
||||
|
||||
@property
|
||||
def stop(self) -> list[str]:
|
||||
"""Get stop sequences sent to the API."""
|
||||
|
||||
@@ -176,6 +176,31 @@ class GeminiCompletion(BaseLLM):
|
||||
else:
|
||||
self.stop_sequences = []
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Extend base config with Gemini/Vertex-specific fields.

    ``project`` is written when truthy, ``location`` when truthy and
    different from "us-central1", and ``top_p`` / ``top_k`` /
    ``max_output_tokens`` when they are not None.

    ``safety_settings`` entries that expose ``category`` and ``threshold``
    attributes are flattened to ``{"category": str, "threshold": str}``
    dicts; any other entry is passed through unchanged. Serializing the
    safety settings is best-effort: on failure the key is omitted and a
    warning is logged instead of the error being silently discarded.
    """
    import logging

    config = super().to_config_dict()
    if self.project:
        config["project"] = self.project
    if self.location and self.location != "us-central1":
        config["location"] = self.location
    if self.top_p is not None:
        config["top_p"] = self.top_p
    if self.top_k is not None:
        config["top_k"] = self.top_k
    if self.max_output_tokens is not None:
        config["max_output_tokens"] = self.max_output_tokens
    if self.safety_settings:
        try:
            config["safety_settings"] = [
                {"category": str(s.category), "threshold": str(s.threshold)}
                if hasattr(s, "category") and hasattr(s, "threshold")
                else s
                for s in self.safety_settings
            ]
        except Exception:
            # Keep serialization best-effort, but leave a trace so a resumed
            # LLM that lacks its safety settings can be diagnosed.
            logging.getLogger(__name__).warning(
                "Could not serialize safety_settings; omitting them from the LLM config",
                exc_info=True,
            )
    return config
|
||||
|
||||
def _initialize_client(self, use_vertexai: bool = False) -> genai.Client:
|
||||
"""Initialize the Google Gen AI client with proper parameter handling.
|
||||
|
||||
|
||||
@@ -329,6 +329,35 @@ class OpenAICompletion(BaseLLM):
|
||||
"""
|
||||
self._last_reasoning_items = None
|
||||
|
||||
def to_config_dict(self) -> dict[str, Any]:
    """Return the base config extended with OpenAI-specific fields.

    ``organization`` and ``project`` are written when truthy, ``timeout``
    when not None, ``max_retries`` when different from 2, and every
    completion parameter only when it is not None.
    """
    settings = super().to_config_dict()

    # Client-level params (from OpenAI SDK)
    for attr in ("organization", "project"):
        current = getattr(self, attr)
        if current:
            settings[attr] = current

    if self.timeout is not None:
        settings["timeout"] = self.timeout

    if self.max_retries != 2:
        settings["max_retries"] = self.max_retries

    # Completion params
    completion_fields = (
        "top_p",
        "frequency_penalty",
        "presence_penalty",
        "max_tokens",
        "max_completion_tokens",
        "seed",
        "reasoning_effort",
    )
    for attr in completion_fields:
        current = getattr(self, attr)
        if current is not None:
            settings[attr] = current

    return settings
|
||||
|
||||
def _get_client_params(self) -> dict[str, Any]:
|
||||
"""Get OpenAI client parameters."""
|
||||
|
||||
|
||||
@@ -772,3 +772,204 @@ class TestEdgeCases:
|
||||
assert result.output == "content"
|
||||
assert result.feedback == "feedback"
|
||||
assert result.outcome is None # No routing, no outcome
|
||||
|
||||
|
||||
class TestLLMConfigPreservation:
    """Tests that LLM config is preserved through @human_feedback serialization.

    PR #4970 introduced _hf_llm stashing so the live LLM object survives
    decorator wrapping for same-process resume. The serialization path
    (_serialize_llm_for_context / _deserialize_llm_from_context) preserves
    config for cross-process resume.
    """

    def test_hf_llm_stashed_on_wrapper_with_llm_instance(self):
        """Test that passing an LLM instance stashes it on the wrapper as _hf_llm."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)

        class ConfigFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        method = ConfigFlow.review
        assert hasattr(method, "_hf_llm"), "_hf_llm not found on wrapper"
        assert method._hf_llm is llm_instance, "_hf_llm is not the same object"

    def test_hf_llm_preserved_on_listen_method(self):
        """Test that _hf_llm is preserved when @human_feedback is on a @listen method."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.7)

        class ListenConfigFlow(Flow):
            @start()
            def generate(self):
                return "draft"

            @listen("generate")
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        method = ListenConfigFlow.review
        assert hasattr(method, "_hf_llm")
        assert method._hf_llm is llm_instance

    def test_hf_llm_accessible_on_instance(self):
        """Test that _hf_llm survives Flow instantiation (bound method access)."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)

        class InstanceFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

        flow = InstanceFlow()
        instance_method = flow.review
        assert hasattr(instance_method, "_hf_llm")
        assert instance_method._hf_llm is llm_instance

    def test_serialize_llm_preserves_config_fields(self):
        """Test that _serialize_llm_for_context captures temperature, base_url, etc."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        llm = LLM(
            model="gpt-4o-mini",
            temperature=0.42,
            base_url="https://custom.example.com/v1",
        )

        serialized = _serialize_llm_for_context(llm)

        assert isinstance(serialized, dict), f"Expected dict, got {type(serialized)}"
        assert serialized["model"] == "openai/gpt-4o-mini"
        assert serialized["temperature"] == 0.42
        assert serialized["base_url"] == "https://custom.example.com/v1"

    def test_serialize_llm_excludes_api_key(self):
        """Test that api_key is NOT included in serialized output (security)."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        llm = LLM(model="gpt-4o-mini")

        serialized = _serialize_llm_for_context(llm)
        assert isinstance(serialized, dict)
        assert "api_key" not in serialized

    def test_deserialize_round_trip_preserves_config(self):
        """Test that serialize → deserialize round-trip preserves all config."""
        from crewai.flow.human_feedback import (
            _deserialize_llm_from_context,
            _serialize_llm_for_context,
        )
        from crewai.llm import LLM

        original = LLM(
            model="gpt-4o-mini",
            temperature=0.42,
            base_url="https://custom.example.com/v1",
        )

        serialized = _serialize_llm_for_context(original)
        reconstructed = _deserialize_llm_from_context(serialized)

        assert reconstructed is not None
        assert reconstructed.model == original.model
        assert reconstructed.temperature == original.temperature
        assert reconstructed.base_url == original.base_url

    def test_deserialize_handles_legacy_string_format(self):
        """Test backward compat: plain string still reconstructs an LLM."""
        from crewai.flow.human_feedback import _deserialize_llm_from_context

        reconstructed = _deserialize_llm_from_context("openai/gpt-4o-mini")

        assert reconstructed is not None
        assert reconstructed.model == "gpt-4o-mini"

    def test_deserialize_returns_none_for_none(self):
        """Test that None input returns None."""
        from crewai.flow.human_feedback import _deserialize_llm_from_context

        assert _deserialize_llm_from_context(None) is None

    def test_serialize_llm_preserves_provider_specific_fields(self):
        """Test that provider-specific fields like project/location are serialized."""
        from crewai.flow.human_feedback import _serialize_llm_for_context
        from crewai.llm import LLM

        # Create a Gemini-style LLM with project and non-default location
        llm = LLM(
            model="gemini-2.0-flash",
            provider="gemini",
            project="my-project",
            location="europe-west1",
            temperature=0.3,
        )

        serialized = _serialize_llm_for_context(llm)

        assert isinstance(serialized, dict)
        assert serialized.get("project") == "my-project"
        assert serialized.get("location") == "europe-west1"
        assert serialized.get("temperature") == 0.3

    def test_config_preserved_through_full_flow_execution(self):
        """Test that the LLM with custom config is used during outcome collapsing."""
        from crewai.llm import LLM

        llm_instance = LLM(model="gpt-4o-mini", temperature=0.42)
        collapse_calls = []

        class FullFlow(Flow):
            @start()
            @human_feedback(
                message="Review:",
                emit=["approved", "rejected"],
                llm=llm_instance,
            )
            def review(self):
                return "content"

            @listen("approved")
            def on_approved(self):
                return "done"

        flow = FullFlow()

        def spy_collapse(feedback, outcomes, llm):
            # Record which LLM object the flow hands to the collapse step.
            collapse_calls.append(llm)
            return "approved"

        with (
            patch.object(flow, "_request_human_feedback", return_value="looks good"),
            patch.object(flow, "_collapse_to_outcome", side_effect=spy_collapse),
        ):
            flow.kickoff()

        assert len(collapse_calls) == 1
        # The LLM passed to _collapse_to_outcome should be the original instance
        assert collapse_calls[0] is llm_instance
|
||||
|
||||
Reference in New Issue
Block a user