mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-04-22 02:42:37 +00:00
Compare commits
1 Commits
feat/impro
...
fix/agent-
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
b752986021 |
@@ -114,7 +114,6 @@ class Agent(BaseAgent):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
|
||||
self.llm = create_llm(self.llm)
|
||||
@@ -134,8 +133,11 @@ class Agent(BaseAgent):
|
||||
self.cache_handler = CacheHandler()
|
||||
self.set_cache_handler(self.cache_handler)
|
||||
|
||||
def _set_knowledge(self):
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
try:
|
||||
if self.embedder is None and crew_embedder:
|
||||
self.embedder = crew_embedder
|
||||
|
||||
if self.knowledge_sources:
|
||||
full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
|
||||
knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"
|
||||
|
||||
@@ -351,3 +351,6 @@ class BaseAgent(ABC, BaseModel):
|
||||
if not self._rpm_controller:
|
||||
self._rpm_controller = rpm_controller
|
||||
self.create_agent_executor()
|
||||
|
||||
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
|
||||
pass
|
||||
|
||||
@@ -600,6 +600,7 @@ class Crew(BaseModel):
|
||||
agent.i18n = i18n
|
||||
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
|
||||
agent.crew = self # type: ignore[attr-defined]
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
|
||||
@@ -19,8 +19,6 @@ from typing import (
|
||||
Tuple,
|
||||
Type,
|
||||
Union,
|
||||
get_args,
|
||||
get_origin,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
@@ -174,29 +172,15 @@ class Task(BaseModel):
|
||||
"""
|
||||
if v is not None:
|
||||
sig = inspect.signature(v)
|
||||
positional_args = [
|
||||
param
|
||||
for param in sig.parameters.values()
|
||||
if param.default is inspect.Parameter.empty
|
||||
]
|
||||
if len(positional_args) != 1:
|
||||
if len(sig.parameters) != 1:
|
||||
raise ValueError("Guardrail function must accept exactly one parameter")
|
||||
|
||||
# Check return annotation if present, but don't require it
|
||||
return_annotation = sig.return_annotation
|
||||
if return_annotation != inspect.Signature.empty:
|
||||
|
||||
return_annotation_args = get_args(return_annotation)
|
||||
if not (
|
||||
get_origin(return_annotation) is tuple
|
||||
and len(return_annotation_args) == 2
|
||||
and return_annotation_args[0] is bool
|
||||
and (
|
||||
return_annotation_args[1] is Any
|
||||
or return_annotation_args[1] is str
|
||||
or return_annotation_args[1] is TaskOutput
|
||||
or return_annotation_args[1] == Union[str, TaskOutput]
|
||||
)
|
||||
return_annotation == Tuple[bool, Any]
|
||||
or str(return_annotation) == "Tuple[bool, Any]"
|
||||
):
|
||||
raise ValueError(
|
||||
"If return type is annotated, it must be Tuple[bool, Any]"
|
||||
@@ -451,9 +435,9 @@ class Task(BaseModel):
|
||||
content = (
|
||||
json_output
|
||||
if json_output
|
||||
else (
|
||||
pydantic_output.model_dump_json() if pydantic_output else result
|
||||
)
|
||||
else pydantic_output.model_dump_json()
|
||||
if pydantic_output
|
||||
else result
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output))
|
||||
|
||||
@@ -1283,109 +1283,3 @@ def test_interpolate_valid_types():
|
||||
assert parsed["optional"] is None
|
||||
assert parsed["nested"]["flag"] is True
|
||||
assert parsed["nested"]["empty"] is None
|
||||
|
||||
|
||||
def test_guardrail_with_new_style_annotations():
|
||||
"""Test that guardrails with new-style type annotations work correctly."""
|
||||
|
||||
# Define a guardrail with new-style annotation
|
||||
def guardrail(result: TaskOutput) -> tuple[bool, str]:
|
||||
return (True, result.raw.upper())
|
||||
|
||||
agent = MagicMock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.return_value = "test result"
|
||||
agent.crew = None
|
||||
|
||||
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
||||
|
||||
result = task.execute_sync(agent=agent)
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert result.raw == "TEST RESULT"
|
||||
|
||||
|
||||
def test_guardrail_with_specific_return_type():
|
||||
"""Test that guardrails with specific return types work correctly."""
|
||||
|
||||
# Define a guardrail with specific return type
|
||||
def guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
|
||||
if "error" in result.raw.lower():
|
||||
return (False, "Contains error")
|
||||
return (True, result)
|
||||
|
||||
agent = MagicMock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.return_value = "success result"
|
||||
agent.crew = None
|
||||
|
||||
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
||||
|
||||
result = task.execute_sync(agent=agent)
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert result.raw == "success result"
|
||||
|
||||
|
||||
def test_guardrail_with_positional_and_default_args():
|
||||
"""Test that guardrails with positional and default arguments work correctly."""
|
||||
|
||||
# Define a guardrail with a positional argument and a default argument
|
||||
def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]:
|
||||
return (True, result.raw.upper())
|
||||
|
||||
agent = MagicMock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.return_value = "test result"
|
||||
agent.crew = None
|
||||
|
||||
# This should now work with the updated validator
|
||||
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
||||
|
||||
result = task.execute_sync(agent=agent)
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert result.raw == "TEST RESULT"
|
||||
|
||||
|
||||
def test_guardrail_with_multiple_positional_args():
|
||||
"""Test that guardrails with multiple positional arguments are rejected."""
|
||||
|
||||
# Define a guardrail with multiple positional arguments
|
||||
def guardrail(result: TaskOutput, another_required_arg) -> tuple[bool, str]:
|
||||
return (True, result.raw.upper())
|
||||
|
||||
agent = MagicMock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.return_value = "test result"
|
||||
agent.crew = None
|
||||
|
||||
# This should raise a ValueError because guardrail must accept exactly one positional parameter
|
||||
with pytest.raises(ValueError) as excinfo:
|
||||
Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
||||
|
||||
assert "Guardrail function must accept exactly one parameter" in str(excinfo.value)
|
||||
|
||||
|
||||
def test_guardrail_with_positional_and_default_args():
|
||||
"""Validate that the guardrail function has the correct signature and behavior.
|
||||
|
||||
While type hints provide static checking, this validator ensures runtime safety by:
|
||||
1. Verifying the function accepts exactly one required parameter (the TaskOutput)
|
||||
(additional parameters with default values are allowed)
|
||||
2. Checking return type annotations match Tuple[bool, Any] or tuple[bool, Any] if present
|
||||
3. Providing clear, immediate error messages for debugging
|
||||
"""
|
||||
|
||||
# Define a guardrail with a positional argument and a default argument
|
||||
def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]:
|
||||
return (True, result.raw.upper())
|
||||
|
||||
agent = MagicMock()
|
||||
agent.role = "test_agent"
|
||||
agent.execute_task.return_value = "test result"
|
||||
agent.crew = None
|
||||
|
||||
# This should now work with the updated validator
|
||||
task = Task(description="Test task", expected_output="Output", guardrail=guardrail)
|
||||
|
||||
result = task.execute_sync(agent=agent)
|
||||
assert isinstance(result, TaskOutput)
|
||||
assert result.raw == "TEST RESULT"
|
||||
|
||||
Reference in New Issue
Block a user