From f13d307534d3a09bc7adf4af22fb1e70f5af9272 Mon Sep 17 00:00:00 2001 From: Greyson LaLonde Date: Fri, 20 Mar 2026 16:04:52 -0400 Subject: [PATCH 1/2] fix: pass cache_function from BaseTool to CrewStructuredTool --- lib/crewai/src/crewai/tools/base_tool.py | 1 + .../src/crewai/tools/structured_tool.py | 27 ++++++++----- .../tests/tools/test_structured_tool.py | 38 +++++++++++++++++++ 3 files changed, 56 insertions(+), 10 deletions(-) diff --git a/lib/crewai/src/crewai/tools/base_tool.py b/lib/crewai/src/crewai/tools/base_tool.py index 37e9fba09..118fa307b 100644 --- a/lib/crewai/src/crewai/tools/base_tool.py +++ b/lib/crewai/src/crewai/tools/base_tool.py @@ -281,6 +281,7 @@ class BaseTool(BaseModel, ABC): result_as_answer=self.result_as_answer, max_usage_count=self.max_usage_count, current_usage_count=self.current_usage_count, + cache_function=self.cache_function, ) structured_tool._original_tool = self return structured_tool diff --git a/lib/crewai/src/crewai/tools/structured_tool.py b/lib/crewai/src/crewai/tools/structured_tool.py index 4b95caeb7..60a457f3b 100644 --- a/lib/crewai/src/crewai/tools/structured_tool.py +++ b/lib/crewai/src/crewai/tools/structured_tool.py @@ -58,6 +58,7 @@ class CrewStructuredTool: result_as_answer: bool = False, max_usage_count: int | None = None, current_usage_count: int = 0, + cache_function: Callable[..., bool] | None = None, ) -> None: """Initialize the structured tool. @@ -69,6 +70,7 @@ class CrewStructuredTool: result_as_answer: Whether to return the output directly max_usage_count: Maximum number of times this tool can be used. None means unlimited usage. current_usage_count: Current number of times this tool has been used. + cache_function: Function to determine if the tool result should be cached. """ self.name = name self.description = description @@ -78,6 +80,7 @@ class CrewStructuredTool: self.result_as_answer = result_as_answer self.max_usage_count = max_usage_count self.current_usage_count = current_usage_count + self.cache_function = cache_function self._original_tool: BaseTool | None = None # Validate the function signature matches the schema @@ -86,7 +89,7 @@ class CrewStructuredTool: @classmethod def from_function( cls, - func: Callable, + func: Callable[..., Any], name: str | None = None, description: str | None = None, return_direct: bool = False, @@ -147,7 +150,7 @@ class CrewStructuredTool: @staticmethod def _create_schema_from_function( name: str, - func: Callable, + func: Callable[..., Any], ) -> type[BaseModel]: """Create a Pydantic schema from a function's signature. @@ -182,7 +185,7 @@ class CrewStructuredTool: # Create model schema_name = f"{name.title()}Schema" - return create_model(schema_name, **fields) # type: ignore[call-overload] + return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return] def _validate_function_signature(self) -> None: """Validate that the function signature matches the args schema.""" @@ -210,7 +213,7 @@ class CrewStructuredTool: f"not found in args_schema" ) - def _parse_args(self, raw_args: str | dict) -> dict: + def _parse_args(self, raw_args: str | dict[str, Any]) -> dict[str, Any]: """Parse and validate the input arguments against the schema. Args: @@ -234,8 +237,8 @@ class CrewStructuredTool: async def ainvoke( self, - input: str | dict, - config: dict | None = None, + input: str | dict[str, Any], + config: dict[str, Any] | None = None, **kwargs: Any, ) -> Any: """Asynchronously invoke the tool. @@ -269,7 +272,7 @@ class CrewStructuredTool: except Exception: raise - def _run(self, *args, **kwargs) -> Any: + def _run(self, *args: Any, **kwargs: Any) -> Any: """Legacy method for compatibility.""" # Convert args/kwargs to our expected format input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False)) @@ -277,7 +280,10 @@ class CrewStructuredTool: return self.invoke(input_dict) def invoke( - self, input: str | dict, config: dict | None = None, **kwargs: Any + self, + input: str | dict[str, Any], + config: dict[str, Any] | None = None, + **kwargs: Any, ) -> Any: """Main method for tool execution.""" parsed_args = self._parse_args(input) @@ -313,9 +319,10 @@ class CrewStructuredTool: self._original_tool.current_usage_count = self.current_usage_count @property - def args(self) -> dict: + def args(self) -> dict[str, Any]: """Get the tool's input arguments schema.""" - return self.args_schema.model_json_schema()["properties"] + schema: dict[str, Any] = self.args_schema.model_json_schema()["properties"] + return schema def __repr__(self) -> str: return f"CrewStructuredTool(name='{sanitize_tool_name(self.name)}', description='{self.description}')" diff --git a/lib/crewai/tests/tools/test_structured_tool.py b/lib/crewai/tests/tools/test_structured_tool.py index 999c13072..1cb8b3138 100644 --- a/lib/crewai/tests/tools/test_structured_tool.py +++ b/lib/crewai/tests/tools/test_structured_tool.py @@ -38,6 +38,44 @@ def test_initialization(basic_function, schema_class): assert tool.args_schema == schema_class +def test_cache_function_passed_through(basic_function, schema_class): + """Test that cache_function is stored on CrewStructuredTool.""" + + def no_cache(_args: dict, _result: str) -> bool: + return False + + tool = CrewStructuredTool( + name="test_tool", + description="Test tool description", + func=basic_function, + args_schema=schema_class, + cache_function=no_cache, + ) + + assert tool.cache_function is no_cache + + +def test_base_tool_passes_cache_function_to_structured_tool(): + """Test that BaseTool.to_structured_tool propagates cache_function.""" + from crewai.tools import BaseTool + + def no_cache(_args: dict, _result: str) -> bool: + return False + + class MyCacheTool(BaseTool): + name: str = "cache_test" + description: str = "tool for testing cache passthrough" + + def _run(self, query: str = "") -> str: + return "result" + + my_tool = MyCacheTool() + my_tool.cache_function = no_cache # type: ignore[assignment] + structured = my_tool.to_structured_tool() + + assert structured.cache_function is no_cache + + def test_from_function(basic_function): """Test creating tool from function""" tool = CrewStructuredTool.from_function( From 09b84dd2b032cc2f2cdc6b6f4751ede436fc9005 Mon Sep 17 00:00:00 2001 From: alex-clawd Date: Fri, 20 Mar 2026 14:42:28 -0700 Subject: [PATCH 2/2] fix: preserve full LLM config across HITL resume for non-OpenAI providers (#4970) When a flow with @human_feedback(llm=create_llm()) pauses for HITL and later resumes: 1. The LLM object was being serialized to just a model string via _serialize_llm_for_context() (e.g. 'gemini/gemini-3.1-flash-lite-preview') 2. On resume, resume_async() was creating LLM(model=string) with NO credentials, project, location, safety_settings, or client_params 3. OpenAI worked by accident (OPENAI_API_KEY from env), but Gemini with service accounts broke This fix: - Stashes the live LLM object on the wrapper as _hf_llm attribute - On resume, looks up the method and retrieves the live LLM if available - Falls back to the serialized string for backward compatibility - Preserves _hf_llm through FlowMethod wrapper decorators Co-authored-by: Joao Moura Co-authored-by: Claude Opus 4.5 --- lib/crewai/src/crewai/flow/flow.py | 20 +- lib/crewai/src/crewai/flow/flow_wrappers.py | 1 + lib/crewai/src/crewai/flow/human_feedback.py | 8 + lib/crewai/tests/test_async_human_feedback.py | 272 ++++++++++++++++++ 4 files changed, 300 insertions(+), 1 deletion(-) diff --git a/lib/crewai/src/crewai/flow/flow.py b/lib/crewai/src/crewai/flow/flow.py index 99c5edab4..a04324462 100644 --- a/lib/crewai/src/crewai/flow/flow.py +++ b/lib/crewai/src/crewai/flow/flow.py @@ -1315,7 +1315,25 @@ class Flow(Generic[T], metaclass=FlowMeta): context = self._pending_feedback_context emit = context.emit default_outcome = context.default_outcome - llm = context.llm + + # Try to get the live LLM from the re-imported decorator instead of the + # serialized string. When a flow pauses for HITL and resumes (possibly in + # a different process), context.llm only contains a model string like + # 'gemini/gemini-3-flash-preview'. This loses credentials, project, + # location, safety_settings, and client_params. By looking up the method + # on the re-imported flow class, we can retrieve the fully-configured LLM + # that was passed to the @human_feedback decorator. + llm = context.llm # fallback to serialized string + method = self._methods.get(FlowMethodName(context.method_name)) + if method is not None: + live_llm = getattr(method, "_hf_llm", None) + if live_llm is not None: + from crewai.llms.base_llm import BaseLLM as BaseLLMClass + + # Only use live LLM if it's a BaseLLM instance (not a string) + # String values offer no benefit over the serialized context.llm + if isinstance(live_llm, BaseLLMClass): + llm = live_llm # Determine outcome collapsed_outcome: str | None = None diff --git a/lib/crewai/src/crewai/flow/flow_wrappers.py b/lib/crewai/src/crewai/flow/flow_wrappers.py index ace2fe727..3eaa67699 100644 --- a/lib/crewai/src/crewai/flow/flow_wrappers.py +++ b/lib/crewai/src/crewai/flow/flow_wrappers.py @@ -75,6 +75,7 @@ class FlowMethod(Generic[P, R]): "__is_router__", "__router_paths__", "__human_feedback_config__", + "_hf_llm", # Live LLM object for HITL resume ]: if hasattr(meth, attr): setattr(self, attr, getattr(meth, attr)) diff --git a/lib/crewai/src/crewai/flow/human_feedback.py b/lib/crewai/src/crewai/flow/human_feedback.py index 7389b8a9e..61e99fce5 100644 --- a/lib/crewai/src/crewai/flow/human_feedback.py +++ b/lib/crewai/src/crewai/flow/human_feedback.py @@ -572,6 +572,14 @@ def human_feedback( wrapper.__is_router__ = True wrapper.__router_paths__ = list(emit) + # Stash the live LLM object for HITL resume to retrieve. + # When a flow pauses for human feedback and later resumes (possibly in a + # different process), the serialized context only contains a model string. + # By storing the original LLM on the wrapper, resume_async can retrieve + # the fully-configured LLM (with credentials, project, safety_settings, etc.) + # instead of creating a bare LLM from just the model string. + wrapper._hf_llm = llm + return wrapper # type: ignore[no-any-return] return decorator diff --git a/lib/crewai/tests/test_async_human_feedback.py b/lib/crewai/tests/test_async_human_feedback.py index f4977858b..3fc222387 100644 --- a/lib/crewai/tests/test_async_human_feedback.py +++ b/lib/crewai/tests/test_async_human_feedback.py @@ -1216,3 +1216,275 @@ class TestAsyncHumanFeedbackEdgeCases: assert flow.last_human_feedback.outcome == "approved" assert flow.last_human_feedback.feedback == "" + + +# ============================================================================= +# Tests for _hf_llm attribute and live LLM resolution on resume +# ============================================================================= + + +class TestLiveLLMPreservationOnResume: + """Tests for preserving the full LLM config across HITL resume.""" + + def test_hf_llm_attribute_set_on_wrapper_with_basellm(self) -> None: + """Test that _hf_llm is set on the wrapper when llm is a BaseLLM instance.""" + from crewai.llms.base_llm import BaseLLM + + # Create a mock BaseLLM object + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.model = "gemini/gemini-3-flash" + + class TestFlow(Flow): + @start() + @human_feedback( + message="Review:", + emit=["approved", "rejected"], + llm=mock_llm, + ) + def review(self): + return "content" + + flow = TestFlow() + method = flow._methods.get("review") + assert method is not None + assert hasattr(method, "_hf_llm") + assert method._hf_llm is mock_llm + + def test_hf_llm_attribute_set_on_wrapper_with_string(self) -> None: + """Test that _hf_llm is set on the wrapper even when llm is a string.""" + + class TestFlow(Flow): + @start() + @human_feedback( + message="Review:", + emit=["approved", "rejected"], + llm="gpt-4o-mini", + ) + def review(self): + return "content" + + flow = TestFlow() + method = flow._methods.get("review") + assert method is not None + assert hasattr(method, "_hf_llm") + assert method._hf_llm == "gpt-4o-mini" + + @patch("crewai.flow.flow.crewai_event_bus.emit") + def test_resume_async_uses_live_basellm_over_serialized_string( + self, mock_emit: MagicMock + ) -> None: + """Test that resume_async uses the live BaseLLM from decorator instead of serialized string. + + This is the main bug fix: when a flow resumes, it should use the fully-configured + LLM from the re-imported decorator (with credentials, project, etc.) instead of + creating a new LLM from just the model string. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + from crewai.llms.base_llm import BaseLLM + + # Create a mock BaseLLM with full config (simulating Gemini with service account) + live_llm = MagicMock(spec=BaseLLM) + live_llm.model = "gemini/gemini-3-flash" + + class TestFlow(Flow): + result_path: str = "" + + @start() + @human_feedback( + message="Approve?", + emit=["approved", "rejected"], + llm=live_llm, # Full LLM object with credentials + ) + def review(self): + return "content" + + @listen("approved") + def handle_approved(self): + self.result_path = "approved" + return "Approved!" + + # Save pending feedback with just a model STRING (simulating serialization) + context = PendingFeedbackContext( + flow_id="live-llm-test", + flow_class="TestFlow", + method_name="review", + method_output="content", + message="Approve?", + emit=["approved", "rejected"], + llm="gemini/gemini-3-flash", # Serialized string, NOT the live object + ) + persistence.save_pending_feedback( + flow_uuid="live-llm-test", + context=context, + state_data={"id": "live-llm-test"}, + ) + + # Restore flow - this re-imports the class with the live LLM + flow = TestFlow.from_pending("live-llm-test", persistence) + + # Mock _collapse_to_outcome to capture what LLM it receives + captured_llm = [] + + def capture_llm(feedback, outcomes, llm): + captured_llm.append(llm) + return "approved" + + with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm): + flow.resume("looks good!") + + # The key assertion: _collapse_to_outcome received the LIVE BaseLLM object, + # NOT the serialized string. The live_llm was captured at class definition + # time and stored on the method wrapper as _hf_llm. + assert len(captured_llm) == 1 + # Verify it's the same object that was passed to the decorator + # (which is stored on the method's _hf_llm attribute) + method = flow._methods.get("review") + assert method is not None + assert captured_llm[0] is method._hf_llm + # And verify it's a BaseLLM instance, not a string + assert isinstance(captured_llm[0], BaseLLM) + + @patch("crewai.flow.flow.crewai_event_bus.emit") + def test_resume_async_falls_back_to_serialized_string_when_no_hf_llm( + self, mock_emit: MagicMock + ) -> None: + """Test that resume_async falls back to context.llm when _hf_llm is not available. + + This ensures backward compatibility with flows that were paused before this fix. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class TestFlow(Flow): + @start() + @human_feedback( + message="Approve?", + emit=["approved", "rejected"], + llm="gpt-4o-mini", + ) + def review(self): + return "content" + + # Save pending feedback + context = PendingFeedbackContext( + flow_id="fallback-test", + flow_class="TestFlow", + method_name="review", + method_output="content", + message="Approve?", + emit=["approved", "rejected"], + llm="gpt-4o-mini", + ) + persistence.save_pending_feedback( + flow_uuid="fallback-test", + context=context, + state_data={"id": "fallback-test"}, + ) + + flow = TestFlow.from_pending("fallback-test", persistence) + + # Remove _hf_llm to simulate old decorator without this attribute + method = flow._methods.get("review") + if hasattr(method, "_hf_llm"): + delattr(method, "_hf_llm") + + # Mock _collapse_to_outcome to capture what LLM it receives + captured_llm = [] + + def capture_llm(feedback, outcomes, llm): + captured_llm.append(llm) + return "approved" + + with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm): + flow.resume("looks good!") + + # Should fall back to the serialized string + assert len(captured_llm) == 1 + assert captured_llm[0] == "gpt-4o-mini" + + @patch("crewai.flow.flow.crewai_event_bus.emit") + def test_resume_async_uses_string_from_context_when_hf_llm_is_string( + self, mock_emit: MagicMock + ) -> None: + """Test that when _hf_llm is a string (not BaseLLM), we still use context.llm. + + String LLM values offer no benefit over the serialized context.llm, + so we don't prefer them. + """ + with tempfile.TemporaryDirectory() as tmpdir: + db_path = os.path.join(tmpdir, "test_flows.db") + persistence = SQLiteFlowPersistence(db_path) + + class TestFlow(Flow): + @start() + @human_feedback( + message="Approve?", + emit=["approved", "rejected"], + llm="gpt-4o-mini", # String LLM + ) + def review(self): + return "content" + + # Save pending feedback + context = PendingFeedbackContext( + flow_id="string-llm-test", + flow_class="TestFlow", + method_name="review", + method_output="content", + message="Approve?", + emit=["approved", "rejected"], + llm="gpt-4o-mini", + ) + persistence.save_pending_feedback( + flow_uuid="string-llm-test", + context=context, + state_data={"id": "string-llm-test"}, + ) + + flow = TestFlow.from_pending("string-llm-test", persistence) + + # Verify _hf_llm is a string + method = flow._methods.get("review") + assert method._hf_llm == "gpt-4o-mini" + + # Mock _collapse_to_outcome to capture what LLM it receives + captured_llm = [] + + def capture_llm(feedback, outcomes, llm): + captured_llm.append(llm) + return "approved" + + with patch.object(flow, "_collapse_to_outcome", side_effect=capture_llm): + flow.resume("looks good!") + + # Should use context.llm since _hf_llm is a string (not BaseLLM) + assert len(captured_llm) == 1 + assert captured_llm[0] == "gpt-4o-mini" + + def test_hf_llm_set_for_async_wrapper(self) -> None: + """Test that _hf_llm is set on async wrapper functions.""" + import asyncio + from crewai.llms.base_llm import BaseLLM + + mock_llm = MagicMock(spec=BaseLLM) + mock_llm.model = "gemini/gemini-3-flash" + + class TestFlow(Flow): + @start() + @human_feedback( + message="Review:", + emit=["approved", "rejected"], + llm=mock_llm, + ) + async def async_review(self): + return "content" + + flow = TestFlow() + method = flow._methods.get("async_review") + assert method is not None + assert hasattr(method, "_hf_llm") + assert method._hf_llm is mock_llm