mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-11 00:58:30 +00:00
Merge branch 'main' into feat/improve-crew-flow-kickoff
@@ -114,7 +114,6 @@ class Agent(BaseAgent):
     @model_validator(mode="after")
     def post_init_setup(self):
-        self._set_knowledge()
         self.agent_ops_agent_name = self.role

         self.llm = create_llm(self.llm)
@@ -134,8 +133,11 @@ class Agent(BaseAgent):
             self.cache_handler = CacheHandler()
             self.set_cache_handler(self.cache_handler)

-    def _set_knowledge(self):
+    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
         try:
+            if self.embedder is None and crew_embedder:
+                self.embedder = crew_embedder
+
             if self.knowledge_sources:
                 full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
                 knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"
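
Note: in the set_knowledge hunk above, the full_pattern regex keeps the knowledge collection name storage-safe: any character outside [a-zA-Z0-9-_\r\n], plus ".." sequences, becomes "_". A quick standalone demonstration of that sanitization (the inputs are illustrative):

    import re

    full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")

    # Spaces and punctuation in an agent role are flattened to underscores.
    print(re.sub(full_pattern, "_", "Senior Data Scientist"))  # Senior_Data_Scientist
    # Path-traversal characters are neutralized one by one.
    print(re.sub(full_pattern, "_", "../etc/passwd"))          # ___etc_passwd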
@@ -351,3 +351,6 @@ class BaseAgent(ABC, BaseModel):
         if not self._rpm_controller:
             self._rpm_controller = rpm_controller
         self.create_agent_executor()
+
+    def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
+        pass
@@ -124,14 +124,15 @@ class CrewAgentParser:
         )

     def _extract_thought(self, text: str) -> str:
-        regex = r"(.*?)(?:\n\nAction|\n\nFinal Answer)"
-        thought_match = re.search(regex, text, re.DOTALL)
-        if thought_match:
-            thought = thought_match.group(1).strip()
-            # Remove any triple backticks from the thought string
-            thought = thought.replace("```", "").strip()
-            return thought
-        return ""
+        thought_index = text.find("\n\nAction")
+        if thought_index == -1:
+            thought_index = text.find("\n\nFinal Answer")
+            if thought_index == -1:
+                return ""
+        thought = text[:thought_index].strip()
+        # Remove any triple backticks from the thought string
+        thought = thought.replace("```", "").strip()
+        return thought

     def _clean_action(self, text: str) -> str:
         """Clean action string by removing non-essential formatting characters."""
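
Note: the rewrite above replaces a DOTALL regex scan with two str.find calls, avoiding regex overhead on long completions while keeping the same contract for well-formed outputs: everything before the first "\n\nAction" (or, failing that, "\n\nFinal Answer") marker is the thought. A standalone check of the new logic, copied out of the hunk:

    def extract_thought(text: str) -> str:
        thought_index = text.find("\n\nAction")
        if thought_index == -1:
            thought_index = text.find("\n\nFinal Answer")
            if thought_index == -1:
                return ""
        thought = text[:thought_index].strip()
        # Remove any triple backticks from the thought string
        return thought.replace("```", "").strip()

    print(extract_thought("I should search first.\n\nAction: search"))  # I should search first.
    print(extract_thought("no markers here"))  # prints an empty string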
@@ -216,10 +216,43 @@ MODELS = {
         "watsonx/ibm/granite-3-8b-instruct",
     ],
     "bedrock": [
+        "bedrock/us.amazon.nova-pro-v1:0",
+        "bedrock/us.amazon.nova-micro-v1:0",
+        "bedrock/us.amazon.nova-lite-v1:0",
+        "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
+        "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
+        "bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
+        "bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
+        "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
+        "bedrock/us.meta.llama3-2-11b-instruct-v1:0",
+        "bedrock/us.meta.llama3-2-3b-instruct-v1:0",
+        "bedrock/us.meta.llama3-2-90b-instruct-v1:0",
+        "bedrock/us.meta.llama3-2-1b-instruct-v1:0",
+        "bedrock/us.meta.llama3-1-8b-instruct-v1:0",
+        "bedrock/us.meta.llama3-1-70b-instruct-v1:0",
+        "bedrock/us.meta.llama3-3-70b-instruct-v1:0",
+        "bedrock/us.meta.llama3-1-405b-instruct-v1:0",
+        "bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
+        "bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
+        "bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
+        "bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
+        "bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
+        "bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
+        "bedrock/amazon.nova-pro-v1:0",
+        "bedrock/amazon.nova-micro-v1:0",
+        "bedrock/amazon.nova-lite-v1:0",
         "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
+        "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
+        "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
+        "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
         "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
-        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
         "bedrock/anthropic.claude-3-opus-20240229-v1:0",
+        "bedrock/anthropic.claude-3-haiku-20240307-v1:0",
         "bedrock/anthropic.claude-v2:1",
         "bedrock/anthropic.claude-v2",
         "bedrock/anthropic.claude-instant-v1",
@@ -234,8 +267,6 @@ MODELS = {
         "bedrock/ai21.j2-mid-v1",
         "bedrock/ai21.j2-ultra-v1",
         "bedrock/ai21.jamba-instruct-v1:0",
-        "bedrock/meta.llama2-13b-chat-v1",
-        "bedrock/meta.llama2-70b-chat-v1",
         "bedrock/mistral.mistral-7b-instruct-v0:2",
         "bedrock/mistral.mixtral-8x7b-instruct-v0:1",
     ],
@@ -600,6 +600,7 @@ class Crew(BaseModel):
             agent.i18n = i18n
             # type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
             agent.crew = self  # type: ignore[attr-defined]
+            agent.set_knowledge(crew_embedder=self.embedder)
             # TODO: Create an AgentFunctionCalling protocol for future refactoring
             if not agent.function_calling_llm:  # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
                 agent.function_calling_llm = self.function_calling_llm  # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
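
Note: with the call added above, the crew pushes its embedder into every agent while preparing kickoff, and Agent.set_knowledge (earlier hunk) only adopts it when the agent has none of its own. A compressed sketch of that precedence, using illustrative stand-in classes rather than the real crewai types:

    from typing import Any, Dict, List, Optional

    class AgentSketch:
        def __init__(self, embedder: Optional[Dict[str, Any]] = None):
            self.embedder = embedder

        def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None) -> None:
            # Agent-level embedder wins; otherwise inherit the crew's.
            if self.embedder is None and crew_embedder:
                self.embedder = crew_embedder

    class CrewSketch:
        def __init__(self, agents: List[AgentSketch], embedder: Optional[Dict[str, Any]] = None):
            self.agents = agents
            self.embedder = embedder

        def prepare_kickoff(self) -> None:
            for agent in self.agents:
                # Mirrors agent.set_knowledge(crew_embedder=self.embedder) above.
                agent.set_knowledge(crew_embedder=self.embedder)

    crew = CrewSketch(
        agents=[AgentSketch(), AgentSketch(embedder={"provider": "ollama"})],
        embedder={"provider": "openai"},
    )
    crew.prepare_kickoff()
    assert crew.agents[0].embedder == {"provider": "openai"}  # inherited from crew
    assert crew.agents[1].embedder == {"provider": "ollama"}  # kept its own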
@@ -894,35 +894,45 @@ class Flow(Generic[T], metaclass=FlowMeta):
         Notes
         -----
         - Routers are executed sequentially to maintain flow control
-        - Each router's result becomes the new trigger_method
+        - Each router's result becomes a new trigger_method
         - Normal listeners are executed in parallel for efficiency
         - Listeners can receive the trigger method's result as a parameter
         """
         # First, handle routers repeatedly until no router triggers anymore
+        router_results = []
+        current_trigger = trigger_method
+
         while True:
             routers_triggered = self._find_triggered_methods(
-                trigger_method, router_only=True
+                current_trigger, router_only=True
             )
             if not routers_triggered:
                 break

             for router_name in routers_triggered:
                 await self._execute_single_listener(router_name, result)
                 # After executing router, the router's result is the path
-                # The last router executed sets the trigger_method
-                # The router result is the last element in self._method_outputs
-                trigger_method = self._method_outputs[-1]
+                router_result = self._method_outputs[-1]
+                if router_result:  # Only add non-None results
+                    router_results.append(router_result)
+                current_trigger = (
+                    router_result  # Update for next iteration of router chain
+                )

-        # Now that no more routers are triggered by current trigger_method,
-        # execute normal listeners
-        listeners_triggered = self._find_triggered_methods(
-            trigger_method, router_only=False
-        )
-        if listeners_triggered:
-            tasks = [
-                self._execute_single_listener(listener_name, result)
-                for listener_name in listeners_triggered
-            ]
-            await asyncio.gather(*tasks)
+        # Now execute normal listeners for all router results and the original trigger
+        all_triggers = [trigger_method] + router_results
+
+        for current_trigger in all_triggers:
+            if current_trigger:  # Skip None results
+                listeners_triggered = self._find_triggered_methods(
+                    current_trigger, router_only=False
+                )
+                if listeners_triggered:
+                    tasks = [
+                        self._execute_single_listener(listener_name, result)
+                        for listener_name in listeners_triggered
+                    ]
+                    await asyncio.gather(*tasks)

     def _find_triggered_methods(
         self, trigger_method: str, router_only: bool
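
Note: the old loop above overwrote trigger_method with only the last router's output, so when several routers fired off the same trigger, the earlier routers' paths never reached their listeners. The new version collects every non-None router result and dispatches listeners for [trigger_method] + router_results. A toy reduction of the two dispatch strategies (not the Flow API, just the selection logic):

    def triggers_dispatched_old(trigger, router_outputs):
        # Only the last router's result survives as the trigger.
        for out in router_outputs:
            trigger = out
        return [trigger]

    def triggers_dispatched_new(trigger, router_outputs):
        # Every non-None router result gets its listeners executed.
        return [trigger] + [out for out in router_outputs if out]

    outputs = ["diabetes", "hypertension", "anemia"]
    print(triggers_dispatched_old("diagnosis_complete", outputs))  # ['anemia'] -- two paths lost
    print(triggers_dispatched_new("diagnosis_complete", outputs))  # all four triggers dispatched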
@@ -4,7 +4,7 @@ SQLite-based implementation of flow state persistence.

 import json
 import sqlite3
-from datetime import datetime
+from datetime import datetime, timezone
 from pathlib import Path
 from typing import Any, Dict, Optional, Union

@@ -34,6 +34,7 @@ class SQLiteFlowPersistence(FlowPersistence):
             ValueError: If db_path is invalid
         """
         from crewai.utilities.paths import db_storage_path

         # Get path from argument or default location
         path = db_path or str(Path(db_storage_path()) / "flow_states.db")

@@ -46,7 +47,8 @@ class SQLiteFlowPersistence(FlowPersistence):
     def init_db(self) -> None:
         """Create the necessary tables if they don't exist."""
         with sqlite3.connect(self.db_path) as conn:
-            conn.execute("""
+            conn.execute(
+                """
             CREATE TABLE IF NOT EXISTS flow_states (
                 id INTEGER PRIMARY KEY AUTOINCREMENT,
                 flow_uuid TEXT NOT NULL,
@@ -54,12 +56,15 @@ class SQLiteFlowPersistence(FlowPersistence):
                 timestamp DATETIME NOT NULL,
                 state_json TEXT NOT NULL
             )
-            """)
+            """
+            )
             # Add index for faster UUID lookups
-            conn.execute("""
+            conn.execute(
+                """
             CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
             ON flow_states(flow_uuid)
-            """)
+            """
+            )

     def save_state(
         self,
@@ -85,19 +90,22 @@ class SQLiteFlowPersistence(FlowPersistence):
         )

         with sqlite3.connect(self.db_path) as conn:
-            conn.execute("""
+            conn.execute(
+                """
             INSERT INTO flow_states (
                 flow_uuid,
                 method_name,
                 timestamp,
                 state_json
             ) VALUES (?, ?, ?, ?)
-            """, (
-                flow_uuid,
-                method_name,
-                datetime.utcnow().isoformat(),
-                json.dumps(state_dict),
-            ))
+            """,
+                (
+                    flow_uuid,
+                    method_name,
+                    datetime.now(timezone.utc).isoformat(),
+                    json.dumps(state_dict),
+                ),
+            )

     def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
         """Load the most recent state for a given flow UUID.
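
Note: the timestamp change in save_state is the Python 3.12 migration: datetime.utcnow() is deprecated and returns a naive datetime, while datetime.now(timezone.utc) is timezone-aware, so the ISO string stored in SQLite carries an explicit +00:00 offset. A quick comparison:

    from datetime import datetime, timezone

    aware = datetime.now(timezone.utc)
    print(aware.tzinfo)       # UTC
    print(aware.isoformat())  # e.g. 2026-01-11T00:58:30.123456+00:00

    naive = datetime.utcnow()  # deprecated since Python 3.12
    print(naive.tzinfo)        # None -- no offset recorded in isoformat()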
@@ -109,13 +117,16 @@ class SQLiteFlowPersistence(FlowPersistence):
             The most recent state as a dictionary, or None if no state exists
         """
         with sqlite3.connect(self.db_path) as conn:
-            cursor = conn.execute("""
+            cursor = conn.execute(
+                """
             SELECT state_json
             FROM flow_states
             WHERE flow_uuid = ?
             ORDER BY id DESC
             LIMIT 1
-            """, (flow_uuid,))
+            """,
+                (flow_uuid,),
+            )
             row = cursor.fetchone()

         if row:
@@ -16,7 +16,8 @@ Example
 import ast
 import inspect
 import textwrap
-from typing import Any, Dict, List, Optional, Set, Union
+from collections import defaultdict, deque
+from typing import Any, Deque, Dict, List, Optional, Set, Union


 def get_possible_return_constants(function: Any) -> Optional[List[str]]:
@@ -118,7 +119,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
     - Processes router paths separately
     """
     levels: Dict[str, int] = {}
-    queue: List[str] = []
+    queue: Deque[str] = deque()
     visited: Set[str] = set()
     pending_and_listeners: Dict[str, Set[str]] = {}

@@ -128,28 +129,35 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
             levels[method_name] = 0
             queue.append(method_name)

+    # Precompute listener dependencies
+    or_listeners = defaultdict(list)
+    and_listeners = defaultdict(set)
+    for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
+        if condition_type == "OR":
+            for method in trigger_methods:
+                or_listeners[method].append(listener_name)
+        elif condition_type == "AND":
+            and_listeners[listener_name] = set(trigger_methods)
+
     # Breadth-first traversal to assign levels
     while queue:
-        current = queue.pop(0)
+        current = queue.popleft()
         current_level = levels[current]
         visited.add(current)

-        for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
-            if condition_type == "OR":
-                if current in trigger_methods:
-                    if (
-                        listener_name not in levels
-                        or levels[listener_name] > current_level + 1
-                    ):
-                        levels[listener_name] = current_level + 1
-                        if listener_name not in visited:
-                            queue.append(listener_name)
-            elif condition_type == "AND":
+        for listener_name in or_listeners[current]:
+            if listener_name not in levels or levels[listener_name] > current_level + 1:
+                levels[listener_name] = current_level + 1
+                if listener_name not in visited:
+                    queue.append(listener_name)
+
+        for listener_name, required_methods in and_listeners.items():
+            if current in required_methods:
                 if listener_name not in pending_and_listeners:
                     pending_and_listeners[listener_name] = set()
-                if current in trigger_methods:
-                    pending_and_listeners[listener_name].add(current)
+                pending_and_listeners[listener_name].add(current)

-                if set(trigger_methods) == pending_and_listeners[listener_name]:
+                if required_methods == pending_and_listeners[listener_name]:
                     if (
                         listener_name not in levels
                         or levels[listener_name] > current_level + 1
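
Note: two optimizations in the hunk above: deque.popleft() is O(1) where list.pop(0) is O(n), and the OR/AND trigger maps are built once instead of re-scanning flow._listeners for every dequeued node, taking the traversal from roughly O(nodes x listeners) toward O(nodes + edges). A self-contained sketch of the precomputation, assuming listeners shaped like crewAI's (condition_type, trigger_methods) pairs:

    from collections import defaultdict, deque

    listeners = {
        "summarize": ("OR", ["fetch", "scrape"]),
        "report": ("AND", ["summarize", "validate"]),
    }

    or_listeners = defaultdict(list)
    and_listeners = defaultdict(set)
    for name, (condition_type, triggers) in listeners.items():
        if condition_type == "OR":
            for method in triggers:
                or_listeners[method].append(name)  # trigger -> listeners, O(1) lookup
        elif condition_type == "AND":
            and_listeners[name] = set(triggers)  # listener -> full required trigger set

    queue = deque(["fetch"])
    current = queue.popleft()  # O(1), unlike list.pop(0)
    print(or_listeners[current])  # ['summarize']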
@@ -159,22 +167,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
                         queue.append(listener_name)

         # Handle router connections
-        if current in flow._routers:
-            router_method_name = current
-            paths = flow._router_paths.get(router_method_name, [])
-            for path in paths:
-                for listener_name, (
-                    condition_type,
-                    trigger_methods,
-                ) in flow._listeners.items():
-                    if path in trigger_methods:
-                        if (
-                            listener_name not in levels
-                            or levels[listener_name] > current_level + 1
-                        ):
-                            levels[listener_name] = current_level + 1
-                            if listener_name not in visited:
-                                queue.append(listener_name)
+        process_router_paths(flow, current, current_level, levels, queue)

     return levels
@@ -227,10 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:


 def dfs_ancestors(
-    node: str,
-    ancestors: Dict[str, Set[str]],
-    visited: Set[str],
-    flow: Any
+    node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
 ) -> None:
     """
     Perform depth-first search to build ancestor relationships.
@@ -274,7 +264,9 @@ def dfs_ancestors(
         dfs_ancestors(listener_name, ancestors, visited, flow)


-def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool:
+def is_ancestor(
+    node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
+) -> bool:
     """
     Check if one node is an ancestor of another.

@@ -339,7 +331,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
     return parent_children


-def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int:
+def get_child_index(
+    parent: str, child: str, parent_children: Dict[str, List[str]]
+) -> int:
     """
     Get the index of a child node in its parent's sorted children list.

@@ -360,3 +354,23 @@ def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str
     children = parent_children.get(parent, [])
     children.sort()
     return children.index(child)
+
+
+def process_router_paths(flow, current, current_level, levels, queue):
+    """
+    Handle the router connections for the current node.
+    """
+    if current in flow._routers:
+        paths = flow._router_paths.get(current, [])
+        for path in paths:
+            for listener_name, (
+                condition_type,
+                trigger_methods,
+            ) in flow._listeners.items():
+                if path in trigger_methods:
+                    if (
+                        listener_name not in levels
+                        or levels[listener_name] > current_level + 1
+                    ):
+                        levels[listener_name] = current_level + 1
+                        queue.append(listener_name)
@@ -64,6 +64,7 @@ LLM_CONTEXT_WINDOW_SIZES = {
     "gpt-4-turbo": 128000,
     "o1-preview": 128000,
     "o1-mini": 128000,
+    "o3-mini": 200000,  # Based on official o3-mini specifications
     # gemini
     "gemini-2.0-flash": 1048576,
     "gemini-1.5-pro": 2097152,
@@ -485,10 +486,23 @@ class LLM:
         """
         Returns the context window size, using 75% of the maximum to avoid
         cutting off messages mid-thread.
+
+        Raises:
+            ValueError: If a model's context window size is outside valid bounds (1024-2097152)
         """
         if self.context_window_size != 0:
             return self.context_window_size

+        MIN_CONTEXT = 1024
+        MAX_CONTEXT = 2097152  # Current max from gemini-1.5-pro
+
+        # Validate all context window sizes
+        for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
+            if value < MIN_CONTEXT or value > MAX_CONTEXT:
+                raise ValueError(
+                    f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
+                )
+
         self.context_window_size = int(
             DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
         )
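
Note: get_context_window_size() hands back 75% of the model's advertised window to leave headroom for the response, and now fails fast on any table entry outside 1024-2097152. A small sketch of the arithmetic; the 0.75 ratio is assumed from the "75% of the maximum" docstring, not read from the source:

    CONTEXT_WINDOW_USAGE_RATIO = 0.75  # assumption based on the docstring
    MIN_CONTEXT, MAX_CONTEXT = 1024, 2097152

    windows = {"o3-mini": 200000, "gemini-1.5-pro": 2097152}
    for model, size in windows.items():
        if size < MIN_CONTEXT or size > MAX_CONTEXT:
            raise ValueError(f"Context window for {model} must be between {MIN_CONTEXT} and {MAX_CONTEXT}")
        print(model, int(size * CONTEXT_WINDOW_USAGE_RATIO))
    # o3-mini 150000
    # gemini-1.5-pro 1572864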
@@ -44,6 +44,7 @@ def create_llm(
     # Extract attributes with explicit types
     model = (
         getattr(llm_value, "model_name", None)
+        or getattr(llm_value, "model", None)
         or getattr(llm_value, "deployment_name", None)
         or str(llm_value)
     )
@@ -30,8 +30,14 @@ class TokenCalcHandler(CustomLogger):
         if hasattr(usage, "prompt_tokens"):
             self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
         if hasattr(usage, "completion_tokens"):
-            self.token_cost_process.sum_completion_tokens(usage.completion_tokens)
-        if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details:
+            self.token_cost_process.sum_completion_tokens(
+                usage.completion_tokens
+            )
+        if (
+            hasattr(usage, "prompt_tokens_details")
+            and usage.prompt_tokens_details
+            and usage.prompt_tokens_details.cached_tokens
+        ):
             self.token_cost_process.sum_cached_prompt_tokens(
                 usage.prompt_tokens_details.cached_tokens
             )
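
Note: the widened guard above matters because prompt_tokens_details can exist yet be None, and cached_tokens can be 0; chaining all three checks avoids an AttributeError as well as recording pointless zeros. A standalone sketch with minimal stand-in usage objects:

    from types import SimpleNamespace

    def cached_tokens_to_record(usage) -> int:
        # Mirrors the three-part guard in the hunk above.
        if (
            hasattr(usage, "prompt_tokens_details")
            and usage.prompt_tokens_details
            and usage.prompt_tokens_details.cached_tokens
        ):
            return usage.prompt_tokens_details.cached_tokens
        return 0

    with_cache = SimpleNamespace(prompt_tokens_details=SimpleNamespace(cached_tokens=42))
    no_details = SimpleNamespace(prompt_tokens_details=None)
    print(cached_tokens_to_record(with_cache))  # 42
    print(cached_tokens_to_record(no_details))  # 0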
@@ -654,3 +654,104 @@ def test_flow_plotting():
     assert isinstance(received_events[0], FlowPlotEvent)
     assert received_events[0].flow_name == "StatelessFlow"
     assert isinstance(received_events[0].timestamp, datetime)
+
+
+def test_multiple_routers_from_same_trigger():
+    """Test that multiple routers triggered by the same method all activate their listeners."""
+    execution_order = []
+
+    class MultiRouterFlow(Flow):
+        def __init__(self):
+            super().__init__()
+            # Set diagnosed conditions to trigger all routers
+            self.state["diagnosed_conditions"] = "DHA"  # Contains D, H, and A
+
+        @start()
+        def scan_medical(self):
+            execution_order.append("scan_medical")
+            return "scan_complete"
+
+        @router(scan_medical)
+        def diagnose_conditions(self):
+            execution_order.append("diagnose_conditions")
+            return "diagnosis_complete"
+
+        @router(diagnose_conditions)
+        def diabetes_router(self):
+            execution_order.append("diabetes_router")
+            if "D" in self.state["diagnosed_conditions"]:
+                return "diabetes"
+            return None
+
+        @listen("diabetes")
+        def diabetes_analysis(self):
+            execution_order.append("diabetes_analysis")
+            return "diabetes_analysis_complete"
+
+        @router(diagnose_conditions)
+        def hypertension_router(self):
+            execution_order.append("hypertension_router")
+            if "H" in self.state["diagnosed_conditions"]:
+                return "hypertension"
+            return None
+
+        @listen("hypertension")
+        def hypertension_analysis(self):
+            execution_order.append("hypertension_analysis")
+            return "hypertension_analysis_complete"
+
+        @router(diagnose_conditions)
+        def anemia_router(self):
+            execution_order.append("anemia_router")
+            if "A" in self.state["diagnosed_conditions"]:
+                return "anemia"
+            return None
+
+        @listen("anemia")
+        def anemia_analysis(self):
+            execution_order.append("anemia_analysis")
+            return "anemia_analysis_complete"
+
+    flow = MultiRouterFlow()
+    flow.kickoff()
+
+    # Verify all methods were called
+    assert "scan_medical" in execution_order
+    assert "diagnose_conditions" in execution_order
+
+    # Verify all routers were called
+    assert "diabetes_router" in execution_order
+    assert "hypertension_router" in execution_order
+    assert "anemia_router" in execution_order
+
+    # Verify all listeners were called - this is the key test for the fix
+    assert "diabetes_analysis" in execution_order
+    assert "hypertension_analysis" in execution_order
+    assert "anemia_analysis" in execution_order
+
+    # Verify execution order constraints
+    assert execution_order.index("diagnose_conditions") > execution_order.index(
+        "scan_medical"
+    )
+
+    # All routers should execute after diagnose_conditions
+    assert execution_order.index("diabetes_router") > execution_order.index(
+        "diagnose_conditions"
+    )
+    assert execution_order.index("hypertension_router") > execution_order.index(
+        "diagnose_conditions"
+    )
+    assert execution_order.index("anemia_router") > execution_order.index(
+        "diagnose_conditions"
+    )
+
+    # All analyses should execute after their respective routers
+    assert execution_order.index("diabetes_analysis") > execution_order.index(
+        "diabetes_router"
+    )
+    assert execution_order.index("hypertension_analysis") > execution_order.index(
+        "hypertension_router"
+    )
+    assert execution_order.index("anemia_analysis") > execution_order.index(
+        "anemia_router"
+    )
@@ -6,7 +6,7 @@ import pytest
 from pydantic import BaseModel

 from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
-from crewai.llm import LLM
+from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
 from crewai.utilities.events import crewai_event_bus
 from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
 from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -285,6 +285,23 @@ def test_o3_mini_reasoning_effort_medium():
     assert isinstance(result, str)
     assert "Paris" in result

+
+def test_context_window_validation():
+    """Test that context window validation works correctly."""
+    # Test valid window size
+    llm = LLM(model="o3-mini")
+    assert llm.get_context_window_size() == int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
+
+    # Test invalid window size
+    with pytest.raises(ValueError) as excinfo:
+        with patch.dict(
+            "crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
+            {"test-model": 500},  # Below minimum
+            clear=True,
+        ):
+            llm = LLM(model="test-model")
+            llm.get_context_window_size()
+    assert "must be between 1024 and 2097152" in str(excinfo.value)
+

 @pytest.mark.vcr(filter_headers=["authorization"])
 @pytest.fixture