mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-06 01:32:36 +00:00
fix: distinguish explicit name via _name_was_explicit flag
This commit is contained in:
@@ -212,6 +212,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
default_factory=TaskOutputStorageHandler
|
||||
)
|
||||
_kickoff_event_id: str | None = PrivateAttr(default=None)
|
||||
_name_was_explicit: bool = PrivateAttr(default=False)
|
||||
|
||||
name: str | None = Field(default=None)
|
||||
cache: bool = Field(default=True)
|
||||
@@ -549,7 +550,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
@model_validator(mode="after")
|
||||
def _resolve_name(self) -> Self:
|
||||
"""Fall back to the class name when no explicit name is provided."""
|
||||
if self.name is None:
|
||||
# Snapshot whether `name` was user-provided before the assignment below
|
||||
# pollutes `model_fields_set`.
|
||||
self._name_was_explicit = "name" in self.model_fields_set
|
||||
if not self._name_was_explicit:
|
||||
self.name = type(self).__name__
|
||||
return self
|
||||
|
||||
|
||||
@@ -238,10 +238,10 @@ def crew(
|
||||
|
||||
crew_instance: Crew = _call_method(meth, self, *args, **kwargs)
|
||||
|
||||
# Override only when the Crew's name is the auto-resolved class-name
|
||||
# fallback, so an explicit `Crew(name=...)` inside the factory wins.
|
||||
# Propagate the @CrewBase class name only when the user didn't pass an
|
||||
# explicit `name=` to the Crew constructor inside the factory method.
|
||||
crewbase_name = getattr(self, "_crew_name", None)
|
||||
if crewbase_name and crew_instance.name == type(crew_instance).__name__:
|
||||
if crewbase_name and not crew_instance._name_was_explicit:
|
||||
crew_instance.name = crewbase_name
|
||||
|
||||
def callback_wrapper(
|
||||
|
||||
@@ -293,6 +293,33 @@ def test_crew_decorator_preserves_explicit_name():
|
||||
assert crew_instance.name == "My Explicit Name"
|
||||
|
||||
|
||||
def test_crew_decorator_preserves_explicit_name_matching_class_name():
|
||||
"""Explicit name that happens to equal the class-name fallback must still win."""
|
||||
sample_agent = Agent(role="r", goal="g", backstory="b")
|
||||
sample_task = Task(description="d", expected_output="o", agent=sample_agent)
|
||||
|
||||
@CrewBase
|
||||
class AmbiguousFactory:
|
||||
agents_config = None
|
||||
tasks_config = None
|
||||
agents: list[BaseAgent] = [sample_agent]
|
||||
tasks: list[Task] = [sample_task]
|
||||
|
||||
@crew
|
||||
def crew(self):
|
||||
# "Crew" is the class-name fallback the validator would use, but the
|
||||
# user set it explicitly — propagation must still skip the override.
|
||||
return Crew(
|
||||
name="Crew",
|
||||
agents=[sample_agent],
|
||||
tasks=[sample_task],
|
||||
)
|
||||
|
||||
factory_cls = cast(type[Any], AmbiguousFactory)
|
||||
crew_instance: Crew = cast(Any, factory_cls()).crew()
|
||||
assert crew_instance.name == "Crew"
|
||||
|
||||
|
||||
@tool
|
||||
def simple_tool():
|
||||
"""Return 'Hi!'"""
|
||||
|
||||
Reference in New Issue
Block a user