Compare commits

..

5 Commits

Author SHA1 Message Date
Brandon Hancock (bhancock_ai)
3d148090d9 Merge branch 'main' into bugfix/fix-type-error-in-token-counter 2025-02-26 13:25:02 -05:00
Brandon Hancock (bhancock_ai)
8fedbe49cb Add support for python 3.10 (#2230) 2025-02-26 13:24:31 -05:00
Lorenze Jay
82cfd353b3 Merge branch 'main' into bugfix/fix-type-error-in-token-counter 2025-02-26 10:01:19 -08:00
Lorenze Jay
1e8ee247ca feat: Enhance agent knowledge setup with optional crew embedder (#2232)
- Modify `Agent` class to add `set_knowledge` method
- Allow setting embedder from crew-level configuration
- Remove `_set_knowledge` method from initialization
- Update `Crew` class to set agent knowledge during agent setup
- Add default implementation in `BaseAgent` for compatibility
2025-02-26 12:10:43 -05:00
Brandon Hancock
0903bbeca2 Fix type issue 2025-02-25 12:18:52 -05:00
7 changed files with 47 additions and 146 deletions

View File

@@ -114,7 +114,6 @@ class Agent(BaseAgent):
@model_validator(mode="after")
def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm)
@@ -134,8 +133,11 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler)
def _set_knowledge(self):
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
if self.knowledge_sources:
full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"

View File

@@ -351,3 +351,6 @@ class BaseAgent(ABC, BaseModel):
if not self._rpm_controller:
self._rpm_controller = rpm_controller
self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
    """Default knowledge hook: accept the crew-level embedder and do nothing.

    Args:
        crew_embedder: Optional embedder configuration handed down by the
            crew. This base implementation ignores it.

    Returns:
        None. Subclasses may override this method to actually configure
        knowledge.
    """
    return None

View File

@@ -600,6 +600,7 @@ class Crew(BaseModel):
agent.i18n = i18n
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
agent.crew = self # type: ignore[attr-defined]
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"

View File

@@ -4,7 +4,7 @@ SQLite-based implementation of flow state persistence.
import json
import sqlite3
from datetime import datetime
from datetime import datetime, timezone
from pathlib import Path
from typing import Any, Dict, Optional, Union
@@ -34,6 +34,7 @@ class SQLiteFlowPersistence(FlowPersistence):
ValueError: If db_path is invalid
"""
from crewai.utilities.paths import db_storage_path
# Get path from argument or default location
path = db_path or str(Path(db_storage_path()) / "flow_states.db")
@@ -46,7 +47,8 @@ class SQLiteFlowPersistence(FlowPersistence):
def init_db(self) -> None:
"""Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
conn.execute(
"""
CREATE TABLE IF NOT EXISTS flow_states (
id INTEGER PRIMARY KEY AUTOINCREMENT,
flow_uuid TEXT NOT NULL,
@@ -54,12 +56,15 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL
)
""")
"""
)
# Add index for faster UUID lookups
conn.execute("""
conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid)
""")
"""
)
def save_state(
self,
@@ -85,19 +90,22 @@ class SQLiteFlowPersistence(FlowPersistence):
)
with sqlite3.connect(self.db_path) as conn:
conn.execute("""
conn.execute(
"""
INSERT INTO flow_states (
flow_uuid,
method_name,
timestamp,
state_json
) VALUES (?, ?, ?, ?)
""", (
flow_uuid,
method_name,
datetime.utcnow().isoformat(),
json.dumps(state_dict),
))
""",
(
flow_uuid,
method_name,
datetime.now(timezone.utc).isoformat(),
json.dumps(state_dict),
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID.
@@ -109,13 +117,16 @@ class SQLiteFlowPersistence(FlowPersistence):
The most recent state as a dictionary, or None if no state exists
"""
with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute("""
cursor = conn.execute(
"""
SELECT state_json
FROM flow_states
WHERE flow_uuid = ?
ORDER BY id DESC
LIMIT 1
""", (flow_uuid,))
""",
(flow_uuid,),
)
row = cursor.fetchone()
if row:

View File

@@ -19,8 +19,6 @@ from typing import (
Tuple,
Type,
Union,
get_args,
get_origin,
)
from pydantic import (
@@ -174,29 +172,15 @@ class Task(BaseModel):
"""
if v is not None:
sig = inspect.signature(v)
positional_args = [
param
for param in sig.parameters.values()
if param.default is inspect.Parameter.empty
]
if len(positional_args) != 1:
if len(sig.parameters) != 1:
raise ValueError("Guardrail function must accept exactly one parameter")
# Check return annotation if present, but don't require it
return_annotation = sig.return_annotation
if return_annotation != inspect.Signature.empty:
return_annotation_args = get_args(return_annotation)
if not (
get_origin(return_annotation) is tuple
and len(return_annotation_args) == 2
and return_annotation_args[0] is bool
and (
return_annotation_args[1] is Any
or return_annotation_args[1] is str
or return_annotation_args[1] is TaskOutput
or return_annotation_args[1] == Union[str, TaskOutput]
)
return_annotation == Tuple[bool, Any]
or str(return_annotation) == "Tuple[bool, Any]"
):
raise ValueError(
"If return type is annotated, it must be Tuple[bool, Any]"
@@ -451,9 +435,9 @@ class Task(BaseModel):
content = (
json_output
if json_output
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
else pydantic_output.model_dump_json()
if pydantic_output
else result
)
self._save_file(content)
crewai_event_bus.emit(self, TaskCompletedEvent(output=task_output))

View File

@@ -30,8 +30,14 @@ class TokenCalcHandler(CustomLogger):
if hasattr(usage, "prompt_tokens"):
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
if hasattr(usage, "completion_tokens"):
self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
self.token_cost_process.sum_completion_tokens(
usage.completion_tokens
)
if (
hasattr(usage, "prompt_tokens_details")
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
self.token_cost_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens
)

View File

@@ -1283,109 +1283,3 @@ def test_interpolate_valid_types():
assert parsed["optional"] is None
assert parsed["nested"]["flag"] is True
assert parsed["nested"]["empty"] is None
def test_guardrail_with_new_style_annotations():
    """A guardrail annotated with builtin ``tuple[bool, str]`` is accepted and applied.

    The guardrail upper-cases the raw output, so the executed task must
    return a ``TaskOutput`` whose ``raw`` is the upper-cased agent answer.
    """

    def guardrail(result: TaskOutput) -> tuple[bool, str]:
        # Builtin generic (PEP 585) return annotation rather than typing.Tuple.
        return (True, result.raw.upper())

    # Stand-in agent: only the attributes this test sets are configured.
    mock_agent = MagicMock(role="test_agent", crew=None)
    mock_agent.execute_task.return_value = "test result"

    guarded_task = Task(
        description="Test task",
        expected_output="Output",
        guardrail=guardrail,
    )
    output = guarded_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    assert output.raw == "TEST RESULT"
def test_guardrail_with_specific_return_type():
    """A guardrail annotated ``tuple[bool, TaskOutput]`` lets a clean result through unchanged.

    The agent answer contains no "error", so the guardrail returns the
    original output and the task result keeps the raw text as produced.
    """

    def guardrail(result: TaskOutput) -> tuple[bool, TaskOutput]:
        # Reject any output mentioning "error" (case-insensitive); otherwise
        # hand the output back untouched.
        has_error = "error" in result.raw.lower()
        return (False, "Contains error") if has_error else (True, result)

    mock_agent = MagicMock(role="test_agent", crew=None)
    mock_agent.execute_task.return_value = "success result"

    guarded_task = Task(
        description="Test task",
        expected_output="Output",
        guardrail=guardrail,
    )
    output = guarded_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    assert output.raw == "success result"
def test_guardrail_with_positional_and_default_args():
    """A guardrail with one required parameter plus a defaulted one is accepted.

    The extra ``optional_arg`` has a default value; the test expects the task
    to be constructed and the guardrail's upper-casing to be applied.
    """

    def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]:
        # Only `result` is required; `optional_arg` is never supplied here.
        return (True, result.raw.upper())

    mock_agent = MagicMock(role="test_agent", crew=None)
    mock_agent.execute_task.return_value = "test result"

    # Construction is expected to succeed despite the second (defaulted) parameter.
    guarded_task = Task(
        description="Test task",
        expected_output="Output",
        guardrail=guardrail,
    )
    output = guarded_task.execute_sync(agent=mock_agent)

    assert isinstance(output, TaskOutput)
    assert output.raw == "TEST RESULT"
def test_guardrail_with_multiple_positional_args():
    """A guardrail with two required parameters is rejected at Task construction."""

    def guardrail(result: TaskOutput, another_required_arg) -> tuple[bool, str]:
        # Second parameter has no default, so the signature is invalid for a guardrail.
        return (True, result.raw.upper())

    mock_agent = MagicMock(role="test_agent", crew=None)
    mock_agent.execute_task.return_value = "test result"

    # Building the Task must fail: a guardrail may take exactly one positional parameter.
    with pytest.raises(ValueError) as raised:
        Task(
            description="Test task",
            expected_output="Output",
            guardrail=guardrail,
        )

    expected_message = "Guardrail function must accept exactly one parameter"
    assert expected_message in str(raised.value)
def test_guardrail_with_required_and_optional_parameters():
    """A guardrail taking one required and one defaulted parameter runs end to end.

    Renamed from ``test_guardrail_with_positional_and_default_args``: that name
    is already used by another test in this module, so this definition
    replaced it at import time and only one of the two was ever collected.

    The previous docstring was copied from the guardrail validator and did
    not describe this test. What is exercised here:

    1. ``Task`` construction accepts a guardrail whose only extra parameter
       has a default value.
    2. Executing the task applies the guardrail, so the returned
       ``TaskOutput.raw`` is the upper-cased agent answer.
    """

    def guardrail(result: TaskOutput, optional_arg=None) -> tuple[bool, str]:
        # `optional_arg` is never passed; it exists only to shape the signature.
        return (True, result.raw.upper())

    agent = MagicMock()
    agent.role = "test_agent"
    agent.execute_task.return_value = "test result"
    agent.crew = None

    # Construction is expected to succeed with the defaulted second parameter.
    task = Task(description="Test task", expected_output="Output", guardrail=guardrail)

    result = task.execute_sync(agent=agent)
    assert isinstance(result, TaskOutput)
    assert result.raw == "TEST RESULT"