mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-05-05 17:22:36 +00:00)
fix: prevent shared LLM stop words mutation across agents
Some checks failed
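The problem in miniature (an illustrative plain-Python sketch, not crewAI API; class and variable names are invented for demonstration): the initializer code removed below merged each executor's stop words into the shared LLM object in place, so the merge outlived the executor and was visible to every other agent holding that object.

# Illustrative sketch of the removed pattern and why it leaks:
class SharedLLM:
    def __init__(self, stop: list[str]) -> None:
        self.stop = stop

class OldExecutor:
    """Mirrors the removed init code: merge stop words into the shared instance."""
    def __init__(self, llm: SharedLLM, stop_words: list[str]) -> None:
        self.llm = llm
        # The removed pattern: reassigning the shared object's attribute
        # mutates state seen by every other executor sharing this LLM.
        self.llm.stop = list(set(self.llm.stop + stop_words))

llm = SharedLLM(stop=["Original:"])
OldExecutor(llm, ["StopA:"])
OldExecutor(llm, ["StopB:"])
# Both executors' words now live permanently on the one LLM object:
assert set(llm.stop) == {"Original:", "StopA:", "StopB:"}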
@@ -1102,16 +1102,6 @@ class Agent(BaseAgent):
         self.agent_executor.tools_handler = self.tools_handler
         self.agent_executor.request_within_rpm_limit = rpm_limit_fn
-
-        if isinstance(self.agent_executor.llm, BaseLLM):
-            existing_stop = getattr(self.agent_executor.llm, "stop", [])
-            self.agent_executor.llm.stop = list(
-                set(
-                    existing_stop + stop_words
-                    if isinstance(existing_stop, list)
-                    else stop_words
-                )
-            )
 
     def get_delegation_tools(self, agents: Sequence[BaseAgent]) -> list[BaseTool]:
         agent_tools = AgentTools(agents=agents)
         return agent_tools.tools()
@@ -49,6 +49,7 @@ from crewai.hooks.tool_hooks import (
 )
 from crewai.types.callback import SerializableCallable
 from crewai.utilities.agent_utils import (
+    _llm_stop_words_applied,
     aget_llm_response,
     convert_tools_to_openai_schema,
     enforce_rpm_limit,
@@ -141,15 +142,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
             self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
         if not self.after_llm_call_hooks:
             self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
-        if self.llm and not isinstance(self.llm, str):
-            existing_stop = getattr(self.llm, "stop", [])
-            self.llm.stop = list(
-                set(
-                    existing_stop + self.stop
-                    if isinstance(existing_stop, list)
-                    else self.stop
-                )
-            )
 
     @property
     def use_stop_words(self) -> bool:
@@ -210,21 +202,22 @@ class CrewAgentExecutor(BaseAgentExecutor):
 
         self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
 
-        try:
-            formatted_answer = self._invoke_loop()
-        except AssertionError:
-            if self.agent.verbose:
-                PRINTER.print(
-                    content="Agent failed to reach a final answer. This is likely a bug - please report it.",
-                    color="red",
-                )
-            raise
-        except Exception as e:
-            handle_unknown_error(PRINTER, e, verbose=self.agent.verbose)
-            raise
+        with _llm_stop_words_applied(self.llm, self):
+            try:
+                formatted_answer = self._invoke_loop()
+            except AssertionError:
+                if self.agent.verbose:
+                    PRINTER.print(
+                        content="Agent failed to reach a final answer. This is likely a bug - please report it.",
+                        color="red",
+                    )
+                raise
+            except Exception as e:
+                handle_unknown_error(PRINTER, e, verbose=self.agent.verbose)
+                raise
 
-        if self.ask_for_human_input:
-            formatted_answer = self._handle_human_feedback(formatted_answer)
+            if self.ask_for_human_input:
+                formatted_answer = self._handle_human_feedback(formatted_answer)
 
         self._save_to_memory(formatted_answer)
         return {"output": formatted_answer.output}
@@ -1082,21 +1075,22 @@ class CrewAgentExecutor(BaseAgentExecutor):
 
         self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
 
-        try:
-            formatted_answer = await self._ainvoke_loop()
-        except AssertionError:
-            if self.agent.verbose:
-                PRINTER.print(
-                    content="Agent failed to reach a final answer. This is likely a bug - please report it.",
-                    color="red",
-                )
-            raise
-        except Exception as e:
-            handle_unknown_error(PRINTER, e, verbose=self.agent.verbose)
-            raise
+        with _llm_stop_words_applied(self.llm, self):
+            try:
+                formatted_answer = await self._ainvoke_loop()
+            except AssertionError:
+                if self.agent.verbose:
+                    PRINTER.print(
+                        content="Agent failed to reach a final answer. This is likely a bug - please report it.",
+                        color="red",
+                    )
+                raise
+            except Exception as e:
+                handle_unknown_error(PRINTER, e, verbose=self.agent.verbose)
+                raise
 
-        if self.ask_for_human_input:
-            formatted_answer = await self._ahandle_human_feedback(formatted_answer)
+            if self.ask_for_human_input:
+                formatted_answer = await self._ahandle_human_feedback(formatted_answer)
 
         self._save_to_memory(formatted_answer)
         return {"output": formatted_answer.output}
@@ -71,6 +71,7 @@ from crewai.hooks.types import (
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.utilities.agent_utils import (
+    _llm_stop_words_applied,
     check_native_tool_support,
     enforce_rpm_limit,
     extract_tool_call_info,
@@ -215,12 +216,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
         self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
         self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
 
-        if self.llm:
-            existing_stop = getattr(self.llm, "stop", [])
-            if not isinstance(existing_stop, list):
-                existing_stop = []
-            self.llm.stop = list(set(existing_stop + self.stop_words))
-
         self._state = AgentExecutorState()
         self.max_method_calls = self.max_iter * 10
 
@@ -2601,17 +2596,18 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
             inputs.get("ask_for_human_input", False)
         )
 
-        self.kickoff()
+        with _llm_stop_words_applied(self.llm, self):
+            self.kickoff()
 
-        formatted_answer = self.state.current_answer
+            formatted_answer = self.state.current_answer
 
-        if not isinstance(formatted_answer, AgentFinish):
-            raise RuntimeError(
-                "Agent execution ended without reaching a final answer."
-            )
+            if not isinstance(formatted_answer, AgentFinish):
+                raise RuntimeError(
+                    "Agent execution ended without reaching a final answer."
+                )
 
-        if self.state.ask_for_human_input:
-            formatted_answer = self._handle_human_feedback(formatted_answer)
+            if self.state.ask_for_human_input:
+                formatted_answer = self._handle_human_feedback(formatted_answer)
 
         self._save_to_memory(formatted_answer)
 
@@ -2691,18 +2687,20 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
             inputs.get("ask_for_human_input", False)
         )
 
-        # Use async kickoff directly since we're already in an async context
-        await self.kickoff_async()
+        with _llm_stop_words_applied(self.llm, self):
+            await self.kickoff_async()
 
-        formatted_answer = self.state.current_answer
+            formatted_answer = self.state.current_answer
 
-        if not isinstance(formatted_answer, AgentFinish):
-            raise RuntimeError(
-                "Agent execution ended without reaching a final answer."
-            )
+            if not isinstance(formatted_answer, AgentFinish):
+                raise RuntimeError(
+                    "Agent execution ended without reaching a final answer."
+                )
 
-        if self.state.ask_for_human_input:
-            formatted_answer = await self._ahandle_human_feedback(formatted_answer)
+            if self.state.ask_for_human_input:
+                formatted_answer = await self._ahandle_human_feedback(
+                    formatted_answer
+                )
 
         self._save_to_memory(formatted_answer)
 
@@ -688,7 +688,9 @@ class LLM(BaseLLM):
             "temperature": self.temperature,
             "top_p": self.top_p,
             "n": self.n,
-            "stop": (self.stop or None) if self.supports_stop_words() else None,
+            "stop": (self.stop_sequences or None)
+            if self.supports_stop_words()
+            else None,
             "max_tokens": self.max_tokens or self.max_completion_tokens,
             "presence_penalty": self.presence_penalty,
             "frequency_penalty": self.frequency_penalty,
@@ -72,6 +72,9 @@ _JSON_EXTRACTION_PATTERN: Final[re.Pattern[str]] = re.compile(r"\{.*}", re.DOTALL)
 _current_call_id: contextvars.ContextVar[str | None] = contextvars.ContextVar(
     "_current_call_id", default=None
 )
+_call_stop_override_var: contextvars.ContextVar[dict[int, list[str]] | None] = (
+    contextvars.ContextVar("_call_stop_override_var", default=None)
+)
 
 
 @contextmanager
@@ -85,6 +88,31 @@ def llm_call_context() -> Generator[str, None, None]:
         _current_call_id.reset(token)
 
 
+@contextmanager
+def call_stop_override(
+    llm: BaseLLM, stop: list[str] | None
+) -> Generator[None, None, None]:
+    """Override the stop list for ``llm`` within the current call scope.
+
+    Only ``llm``'s reads via :attr:`BaseLLM.stop_sequences` see ``stop``;
+    other LLM instances (e.g. an agent's ``function_calling_llm``) keep their
+    own ``stop`` field. Passing ``None`` clears any prior override for ``llm``
+    in the same scope. The instance-level ``stop`` field is never mutated,
+    so the override is safe under concurrent execution.
+    """
+    current = _call_stop_override_var.get()
+    new_overrides: dict[int, list[str]] = dict(current) if current else {}
+    if stop is None:
+        new_overrides.pop(id(llm), None)
+    else:
+        new_overrides[id(llm)] = stop
+    token = _call_stop_override_var.set(new_overrides)
+    try:
+        yield
+    finally:
+        _call_stop_override_var.reset(token)
+
+
 def get_current_call_id() -> str:
     """Get current call_id from context"""
     call_id = _current_call_id.get()
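Usage sketch for the new context manager (the `from crewai import LLM` import path and the stop words here are assumptions for illustration; `call_stop_override` is the helper added above, and the behavior shown mirrors the regression tests at the end of this diff):

from crewai import LLM
from crewai.llms.base_llm import call_stop_override

llm = LLM(model="gpt-4", stop=["Original:"])

with call_stop_override(llm, ["Original:", "Observation:"]):
    # Only this instance's reads through `stop_sequences` see the override.
    assert set(llm.stop_sequences) == {"Original:", "Observation:"}
    assert llm.stop == ["Original:"]  # the instance field is never mutated

assert llm.stop_sequences == ["Original:"]  # override cleared on exit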
@@ -158,11 +186,18 @@ class BaseLLM(BaseModel, ABC):
 
     @property
     def stop_sequences(self) -> list[str]:
-        """Alias for ``stop`` — kept for backward compatibility with provider APIs.
+        """Stop list active for the current call.
 
-        Writes are handled by ``__setattr__``, which normalizes and redirects
-        ``stop_sequences`` assignments to the ``stop`` field.
+        Returns the per-instance override set via :func:`call_stop_override`
+        when one is in effect for this LLM; otherwise the instance-level
+        ``stop`` field. Kept under this name for backward compatibility with
+        provider APIs that already read ``stop_sequences``.
         """
+        overrides = _call_stop_override_var.get()
+        if overrides is not None:
+            override = overrides.get(id(self))
+            if override is not None:
+                return override
         return self.stop
 
     _token_usage: dict[str, int] = PrivateAttr(
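Read-path sketch: the override map is keyed by `id(instance)`, so two LLMs active in the same scope resolve independently (model name, stop words, and the `LLM` import path are illustrative assumptions; this mirrors the leak-isolation test below):

from crewai import LLM
from crewai.llms.base_llm import call_stop_override

a = LLM(model="gpt-4", stop=["A:"])
b = LLM(model="gpt-4", stop=["B:"])

with call_stop_override(a, ["A:", "Observation:"]):
    assert set(a.stop_sequences) == {"A:", "Observation:"}
    assert b.stop_sequences == ["B:"]  # per-instance: b keeps its own stop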
@@ -341,7 +376,7 @@ class BaseLLM(BaseModel, ABC):
         Returns:
             True if stop words are configured and can be applied
         """
-        return bool(self.stop)
+        return bool(self.stop_sequences)
 
     def _apply_stop_words(self, content: str) -> str:
         """Apply stop words to truncate response content.
@@ -363,14 +398,14 @@ class BaseLLM(BaseModel, ABC):
         >>> llm._apply_stop_words(response)
         "I need to search.\\n\\nAction: search"
         """
-        if not self.stop or not content:
+        stops = self.stop_sequences
+        if not stops or not content:
             return content
 
         # Find the earliest occurrence of any stop word
         earliest_stop_pos = len(content)
         found_stop_word = None
 
-        for stop_word in self.stop:
+        for stop_word in stops:
             stop_pos = content.find(stop_word)
             if stop_pos != -1 and stop_pos < earliest_stop_pos:
                 earliest_stop_pos = stop_pos
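Because `_apply_stop_words` now reads `stop_sequences`, a per-call override also governs post-hoc truncation for providers that cannot pass `stop` at the API level. A sketch under assumed inputs (the response text and stop word are invented; the exact cut position follows the scan above):

from crewai import LLM
from crewai.llms.base_llm import call_stop_override

llm = LLM(model="gpt-4", stop=[])
raw = "I need to search.\n\nObservation: raw tool output the model echoed"

with call_stop_override(llm, ["Observation:"]):
    truncated = llm._apply_stop_words(raw)  # truncates at the earliest stop word

assert "Observation:" not in truncated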
@@ -679,8 +679,9 @@ class AzureCompletion(BaseLLM):
             params["presence_penalty"] = self.presence_penalty
         if self.max_tokens is not None:
             params["max_tokens"] = self.max_tokens
-        if self.stop and self.supports_stop_words():
-            params["stop"] = self.stop
+        stops = self.stop_sequences
+        if stops and self.supports_stop_words():
+            params["stop"] = stops
 
         # Handle tools/functions for Azure OpenAI models
         if tools and self.is_openai_model:
@@ -1,8 +1,9 @@
 from __future__ import annotations
 
 import asyncio
-from collections.abc import Callable, Sequence
+from collections.abc import Callable, Iterator, Sequence
 import concurrent.futures
+import contextlib
 import contextvars
 from dataclasses import dataclass, field
 from datetime import datetime
@@ -22,7 +23,7 @@ from crewai.agents.parser import (
     parse,
 )
 from crewai.cli.config import Settings
-from crewai.llms.base_llm import BaseLLM
+from crewai.llms.base_llm import BaseLLM, call_stop_override
 from crewai.tools import BaseTool as CrewAITool
 from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
@@ -238,6 +239,38 @@ def extract_task_section(text: str) -> str:
     return text
 
 
+def _executor_stop_words(
+    executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
+) -> list[str]:
+    """Return the executor's stop words, regardless of which field name it uses."""
+    if executor_context is None:
+        return []
+    stops = getattr(executor_context, "stop", None)
+    if stops is None:
+        stops = getattr(executor_context, "stop_words", None)
+    return list(stops) if stops else []
+
+
+@contextlib.contextmanager
+def _llm_stop_words_applied(
+    llm: LLM | BaseLLM,
+    executor_context: CrewAgentExecutor | AgentExecutor | LiteAgent | None,
+) -> Iterator[None]:
+    """Apply the executor's stop words to the LLM for the duration of one call.
+
+    Uses :func:`crewai.llms.base_llm.call_stop_override` so the LLM's stop
+    field is never mutated. Safe under concurrent execution: the override is
+    propagated via a :class:`contextvars.ContextVar` and is scoped to this
+    call's task / thread context.
+    """
+    extra = _executor_stop_words(executor_context)
+    if not extra or not isinstance(llm, BaseLLM) or set(extra).issubset(llm.stop):
+        yield
+        return
+    with call_stop_override(llm, list(set(llm.stop + extra))):
+        yield
+
+
 def has_reached_max_iterations(iterations: int, max_iterations: int) -> bool:
     """Check if the maximum number of iterations has been reached.
 
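Scoping sketch for the helper: `_executor_stop_words` only needs a `stop` or `stop_words` attribute, so a `SimpleNamespace` stands in for an executor here purely for illustration (imports and stop words assumed):

from types import SimpleNamespace

from crewai import LLM
from crewai.utilities.agent_utils import _llm_stop_words_applied

llm = LLM(model="gpt-4", stop=["Original:"])
ctx = SimpleNamespace(stop=["Observation:"])  # hypothetical executor context

def nested_direct_call() -> set[str]:
    # Stands in for direct llm.call sites (e.g. StepExecutor) that bypass
    # get_llm_response: inside the scope they still see the merged list.
    return set(llm.stop_sequences)

with _llm_stop_words_applied(llm, ctx):
    assert nested_direct_call() == {"Original:", "Observation:"}

assert llm.stop_sequences == ["Original:"]  # nothing persists after the scope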
@@ -459,18 +492,15 @@ def get_llm_response(
     """
     messages = _prepare_llm_call(executor_context, messages, printer, verbose=verbose)
 
-    try:
-        answer = llm.call(
-            messages,
-            tools=tools,
-            callbacks=callbacks,
-            available_functions=available_functions,
-            from_task=from_task,
-            from_agent=from_agent,
-            response_model=response_model,
-        )
-    except Exception as e:
-        raise e
+    answer = llm.call(
+        messages,
+        tools=tools,
+        callbacks=callbacks,
+        available_functions=available_functions,
+        from_task=from_task,
+        from_agent=from_agent,
+        response_model=response_model,
+    )
 
     return _validate_and_finalize_llm_response(
         answer, executor_context, printer, verbose=verbose
@@ -515,18 +545,15 @@ async def aget_llm_response(
     """
     messages = _prepare_llm_call(executor_context, messages, printer, verbose=verbose)
 
-    try:
-        answer = await llm.acall(
-            messages,
-            tools=tools,
-            callbacks=callbacks,
-            available_functions=available_functions,
-            from_task=from_task,
-            from_agent=from_agent,
-            response_model=response_model,
-        )
-    except Exception as e:
-        raise e
+    answer = await llm.acall(
+        messages,
+        tools=tools,
+        callbacks=callbacks,
+        available_functions=available_functions,
+        from_task=from_task,
+        from_agent=from_agent,
+        response_model=response_model,
+    )
 
     return _validate_and_finalize_llm_response(
         answer, executor_context, printer, verbose=verbose
@@ -2452,3 +2452,167 @@ def test_agent_mcps_accepts_legacy_prefix_with_tool():
         mcps=["crewai-amp:notion#get_page"],
     )
     assert agent.mcps == ["crewai-amp:notion#get_page"]
+
+
+class TestSharedLLMStopWords:
+    """Regression tests for shared LLM stop words mutation (issue #5141).
+
+    Stop words from one executor must not leak into the shared LLM permanently
+    or pollute other agents sharing that LLM.
+    """
+
+    @staticmethod
+    def _make_executor(llm: LLM, stop_words: list[str]) -> CrewAgentExecutor:
+        """Build a CrewAgentExecutor with minimal deps."""
+        from crewai.agents.tools_handler import ToolsHandler
+
+        agent = Agent(role="r", goal="g", backstory="b")
+        task = Task(description="d", expected_output="o", agent=agent)
+        return CrewAgentExecutor(
+            agent=agent,
+            task=task,
+            llm=llm,
+            crew=None,
+            prompt={"prompt": "p {input} {tool_names} {tools}"},
+            max_iter=5,
+            tools=[],
+            tools_names="",
+            stop_words=stop_words,
+            tools_description="",
+            tools_handler=ToolsHandler(),
+        )
+
+    def test_executor_init_does_not_mutate_shared_llm(self) -> None:
+        """Constructing executors must not touch the shared LLM's stop list."""
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        original = list(shared.stop)
+
+        a = self._make_executor(shared, stop_words=["StopA:"])
+        b = self._make_executor(shared, stop_words=["StopB:"])
+
+        assert shared.stop == original
+        assert a.llm is shared
+        assert b.llm is shared
+
+    def test_effective_stop_reflects_override_inside_context(self) -> None:
+        """Inside the helper, the effective stop list includes the executor's words."""
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        executor = self._make_executor(shared, stop_words=["Observation:"])
+
+        with _llm_stop_words_applied(shared, executor):
+            assert set(shared.stop_sequences) == {"Original:", "Observation:"}
+            assert shared.stop == ["Original:"]
+
+        assert shared.stop == ["Original:"]
+        assert shared.stop_sequences == ["Original:"]
+
+    def test_override_cleared_when_context_raises(self) -> None:
+        """A failed call must still clear the per-call stop override."""
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        executor = self._make_executor(shared, stop_words=["Observation:"])
+
+        try:
+            with _llm_stop_words_applied(shared, executor):
+                raise RuntimeError("boom")
+        except RuntimeError:
+            pass
+
+        assert shared.stop == ["Original:"]
+        assert shared.stop_sequences == ["Original:"]
+
+    def test_override_applies_for_post_processing_when_api_lacks_stop_support(
+        self,
+    ) -> None:
+        """Models that lack API-level stop support still need the override.
+
+        Native providers (e.g. Azure on gpt-5/o-series) read ``stop_sequences``
+        in ``_apply_stop_words`` to truncate the response post-hoc even when
+        ``supports_stop_words()`` returns False, so the override must be set
+        regardless of API-level support. (Issue raised by Cursor Bugbot.)
+        """
+        from unittest.mock import patch
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        executor = self._make_executor(shared, stop_words=["Observation:"])
+
+        with patch.object(shared, "supports_stop_words", return_value=False):
+            with _llm_stop_words_applied(shared, executor):
+                assert set(shared.stop_sequences) == {"Original:", "Observation:"}
+
+        assert shared.stop == ["Original:"]
+        assert shared.stop_sequences == ["Original:"]
+
+    def test_concurrent_overrides_do_not_collide(self) -> None:
+        """Concurrent agents on a shared LLM must each see their own effective stop."""
+        import asyncio
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        exec_a = self._make_executor(shared, stop_words=["StopA:"])
+        exec_b = self._make_executor(shared, stop_words=["StopB:"])
+
+        async def run(executor: CrewAgentExecutor, expected: str) -> set[str]:
+            with _llm_stop_words_applied(shared, executor):
+                await asyncio.sleep(0)
+                seen = set(shared.stop_sequences)
+                assert expected in seen
+                return seen
+
+        async def main() -> tuple[set[str], set[str]]:
+            return await asyncio.gather(
+                run(exec_a, "StopA:"), run(exec_b, "StopB:")
+            )
+
+        a_seen, b_seen = asyncio.run(main())
+        assert a_seen == {"Original:", "StopA:"}
+        assert b_seen == {"Original:", "StopB:"}
+        assert shared.stop == ["Original:"]
+        assert shared.stop_sequences == ["Original:"]
+
+    def test_override_does_not_leak_to_other_llm_instances(self) -> None:
+        """Override for one LLM must not affect another LLM (e.g. function_calling_llm).
+
+        Regression for Cursor Bugbot: a global ContextVar would leak the
+        override to every BaseLLM that reads stop_sequences during the scope.
+        """
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        target = LLM(model="gpt-4", stop=["TargetStop:"])
+        other = LLM(model="gpt-4", stop=["OtherStop:"])
+        executor = self._make_executor(target, stop_words=["Observation:"])
+
+        with _llm_stop_words_applied(target, executor):
+            assert set(target.stop_sequences) == {"TargetStop:", "Observation:"}
+            assert other.stop_sequences == ["OtherStop:"]
+
+        assert target.stop_sequences == ["TargetStop:"]
+        assert other.stop_sequences == ["OtherStop:"]
+
+    def test_override_propagates_to_nested_direct_llm_calls(self) -> None:
+        """Once invoke wraps with the override, nested direct llm.call sites
+        (StepExecutor, handle_max_iterations_exceeded) see the merged stops.
+
+        Regression for Cursor Bugbot: those direct call sites bypass
+        get_llm_response, so the override must be set at executor entry, not
+        only around get_llm_response.
+        """
+        from crewai.utilities.agent_utils import _llm_stop_words_applied
+
+        shared = LLM(model="gpt-4", stop=["Original:"])
+        executor = self._make_executor(shared, stop_words=["Observation:"])
+
+        seen: list[set[str]] = []
+
+        def nested_direct_call() -> None:
+            seen.append(set(shared.stop_sequences))
+
+        with _llm_stop_words_applied(shared, executor):
+            nested_direct_call()
+
+        assert seen == [{"Original:", "Observation:"}]
+        assert shared.stop == ["Original:"]