Merge branch 'main' into feat/improve-crew-flow-kickoff

This commit is contained in:
Brandon Hancock (bhancock_ai)
2025-02-26 14:52:05 -05:00
committed by GitHub
13 changed files with 297 additions and 85 deletions

View File

@@ -114,7 +114,6 @@ class Agent(BaseAgent):
@model_validator(mode="after") @model_validator(mode="after")
def post_init_setup(self): def post_init_setup(self):
self._set_knowledge()
self.agent_ops_agent_name = self.role self.agent_ops_agent_name = self.role
self.llm = create_llm(self.llm) self.llm = create_llm(self.llm)
@@ -134,8 +133,11 @@ class Agent(BaseAgent):
self.cache_handler = CacheHandler() self.cache_handler = CacheHandler()
self.set_cache_handler(self.cache_handler) self.set_cache_handler(self.cache_handler)
def _set_knowledge(self): def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
try: try:
if self.embedder is None and crew_embedder:
self.embedder = crew_embedder
if self.knowledge_sources: if self.knowledge_sources:
full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)") full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)")
knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}" knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}"

View File

@@ -351,3 +351,6 @@ class BaseAgent(ABC, BaseModel):
if not self._rpm_controller: if not self._rpm_controller:
self._rpm_controller = rpm_controller self._rpm_controller = rpm_controller
self.create_agent_executor() self.create_agent_executor()
def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None):
pass

View File

@@ -124,14 +124,15 @@ class CrewAgentParser:
) )
def _extract_thought(self, text: str) -> str: def _extract_thought(self, text: str) -> str:
regex = r"(.*?)(?:\n\nAction|\n\nFinal Answer)" thought_index = text.find("\n\nAction")
thought_match = re.search(regex, text, re.DOTALL) if thought_index == -1:
if thought_match: thought_index = text.find("\n\nFinal Answer")
thought = thought_match.group(1).strip() if thought_index == -1:
# Remove any triple backticks from the thought string return ""
thought = thought.replace("```", "").strip() thought = text[:thought_index].strip()
return thought # Remove any triple backticks from the thought string
return "" thought = thought.replace("```", "").strip()
return thought
def _clean_action(self, text: str) -> str: def _clean_action(self, text: str) -> str:
"""Clean action string by removing non-essential formatting characters.""" """Clean action string by removing non-essential formatting characters."""

View File

@@ -216,10 +216,43 @@ MODELS = {
"watsonx/ibm/granite-3-8b-instruct", "watsonx/ibm/granite-3-8b-instruct",
], ],
"bedrock": [ "bedrock": [
"bedrock/us.amazon.nova-pro-v1:0",
"bedrock/us.amazon.nova-micro-v1:0",
"bedrock/us.amazon.nova-lite-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/us.anthropic.claude-3-opus-20240229-v1:0",
"bedrock/us.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/us.meta.llama3-2-11b-instruct-v1:0",
"bedrock/us.meta.llama3-2-3b-instruct-v1:0",
"bedrock/us.meta.llama3-2-90b-instruct-v1:0",
"bedrock/us.meta.llama3-2-1b-instruct-v1:0",
"bedrock/us.meta.llama3-1-8b-instruct-v1:0",
"bedrock/us.meta.llama3-1-70b-instruct-v1:0",
"bedrock/us.meta.llama3-3-70b-instruct-v1:0",
"bedrock/us.meta.llama3-1-405b-instruct-v1:0",
"bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/eu.meta.llama3-2-3b-instruct-v1:0",
"bedrock/eu.meta.llama3-2-1b-instruct-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/amazon.nova-pro-v1:0",
"bedrock/amazon.nova-micro-v1:0",
"bedrock/amazon.nova-lite-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0",
"bedrock/anthropic.claude-3-5-haiku-20241022-v1:0",
"bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0",
"bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0",
"bedrock/anthropic.claude-3-sonnet-20240229-v1:0", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-3-opus-20240229-v1:0", "bedrock/anthropic.claude-3-opus-20240229-v1:0",
"bedrock/anthropic.claude-3-haiku-20240307-v1:0",
"bedrock/anthropic.claude-v2:1", "bedrock/anthropic.claude-v2:1",
"bedrock/anthropic.claude-v2", "bedrock/anthropic.claude-v2",
"bedrock/anthropic.claude-instant-v1", "bedrock/anthropic.claude-instant-v1",
@@ -234,8 +267,6 @@ MODELS = {
"bedrock/ai21.j2-mid-v1", "bedrock/ai21.j2-mid-v1",
"bedrock/ai21.j2-ultra-v1", "bedrock/ai21.j2-ultra-v1",
"bedrock/ai21.jamba-instruct-v1:0", "bedrock/ai21.jamba-instruct-v1:0",
"bedrock/meta.llama2-13b-chat-v1",
"bedrock/meta.llama2-70b-chat-v1",
"bedrock/mistral.mistral-7b-instruct-v0:2", "bedrock/mistral.mistral-7b-instruct-v0:2",
"bedrock/mistral.mixtral-8x7b-instruct-v0:1", "bedrock/mistral.mixtral-8x7b-instruct-v0:1",
], ],

View File

@@ -600,6 +600,7 @@ class Crew(BaseModel):
agent.i18n = i18n agent.i18n = i18n
# type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]" # type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]"
agent.crew = self # type: ignore[attr-defined] agent.crew = self # type: ignore[attr-defined]
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring # TODO: Create an AgentFunctionCalling protocol for future refactoring
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm" if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm" agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"

View File

@@ -894,35 +894,45 @@ class Flow(Generic[T], metaclass=FlowMeta):
Notes Notes
----- -----
- Routers are executed sequentially to maintain flow control - Routers are executed sequentially to maintain flow control
- Each router's result becomes the new trigger_method - Each router's result becomes a new trigger_method
- Normal listeners are executed in parallel for efficiency - Normal listeners are executed in parallel for efficiency
- Listeners can receive the trigger method's result as a parameter - Listeners can receive the trigger method's result as a parameter
""" """
# First, handle routers repeatedly until no router triggers anymore # First, handle routers repeatedly until no router triggers anymore
router_results = []
current_trigger = trigger_method
while True: while True:
routers_triggered = self._find_triggered_methods( routers_triggered = self._find_triggered_methods(
trigger_method, router_only=True current_trigger, router_only=True
) )
if not routers_triggered: if not routers_triggered:
break break
for router_name in routers_triggered: for router_name in routers_triggered:
await self._execute_single_listener(router_name, result) await self._execute_single_listener(router_name, result)
# After executing router, the router's result is the path # After executing router, the router's result is the path
# The last router executed sets the trigger_method router_result = self._method_outputs[-1]
# The router result is the last element in self._method_outputs if router_result: # Only add non-None results
trigger_method = self._method_outputs[-1] router_results.append(router_result)
current_trigger = (
router_result # Update for next iteration of router chain
)
# Now that no more routers are triggered by current trigger_method, # Now execute normal listeners for all router results and the original trigger
# execute normal listeners all_triggers = [trigger_method] + router_results
listeners_triggered = self._find_triggered_methods(
trigger_method, router_only=False for current_trigger in all_triggers:
) if current_trigger: # Skip None results
if listeners_triggered: listeners_triggered = self._find_triggered_methods(
tasks = [ current_trigger, router_only=False
self._execute_single_listener(listener_name, result) )
for listener_name in listeners_triggered if listeners_triggered:
] tasks = [
await asyncio.gather(*tasks) self._execute_single_listener(listener_name, result)
for listener_name in listeners_triggered
]
await asyncio.gather(*tasks)
def _find_triggered_methods( def _find_triggered_methods(
self, trigger_method: str, router_only: bool self, trigger_method: str, router_only: bool

View File

@@ -4,7 +4,7 @@ SQLite-based implementation of flow state persistence.
import json import json
import sqlite3 import sqlite3
from datetime import datetime from datetime import datetime, timezone
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional, Union from typing import Any, Dict, Optional, Union
@@ -34,6 +34,7 @@ class SQLiteFlowPersistence(FlowPersistence):
ValueError: If db_path is invalid ValueError: If db_path is invalid
""" """
from crewai.utilities.paths import db_storage_path from crewai.utilities.paths import db_storage_path
# Get path from argument or default location # Get path from argument or default location
path = db_path or str(Path(db_storage_path()) / "flow_states.db") path = db_path or str(Path(db_storage_path()) / "flow_states.db")
@@ -46,7 +47,8 @@ class SQLiteFlowPersistence(FlowPersistence):
def init_db(self) -> None: def init_db(self) -> None:
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute(""" conn.execute(
"""
CREATE TABLE IF NOT EXISTS flow_states ( CREATE TABLE IF NOT EXISTS flow_states (
id INTEGER PRIMARY KEY AUTOINCREMENT, id INTEGER PRIMARY KEY AUTOINCREMENT,
flow_uuid TEXT NOT NULL, flow_uuid TEXT NOT NULL,
@@ -54,12 +56,15 @@ class SQLiteFlowPersistence(FlowPersistence):
timestamp DATETIME NOT NULL, timestamp DATETIME NOT NULL,
state_json TEXT NOT NULL state_json TEXT NOT NULL
) )
""") """
)
# Add index for faster UUID lookups # Add index for faster UUID lookups
conn.execute(""" conn.execute(
"""
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
ON flow_states(flow_uuid) ON flow_states(flow_uuid)
""") """
)
def save_state( def save_state(
self, self,
@@ -85,19 +90,22 @@ class SQLiteFlowPersistence(FlowPersistence):
) )
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
conn.execute(""" conn.execute(
"""
INSERT INTO flow_states ( INSERT INTO flow_states (
flow_uuid, flow_uuid,
method_name, method_name,
timestamp, timestamp,
state_json state_json
) VALUES (?, ?, ?, ?) ) VALUES (?, ?, ?, ?)
""", ( """,
flow_uuid, (
method_name, flow_uuid,
datetime.utcnow().isoformat(), method_name,
json.dumps(state_dict), datetime.now(timezone.utc).isoformat(),
)) json.dumps(state_dict),
),
)
def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]:
"""Load the most recent state for a given flow UUID. """Load the most recent state for a given flow UUID.
@@ -109,13 +117,16 @@ class SQLiteFlowPersistence(FlowPersistence):
The most recent state as a dictionary, or None if no state exists The most recent state as a dictionary, or None if no state exists
""" """
with sqlite3.connect(self.db_path) as conn: with sqlite3.connect(self.db_path) as conn:
cursor = conn.execute(""" cursor = conn.execute(
"""
SELECT state_json SELECT state_json
FROM flow_states FROM flow_states
WHERE flow_uuid = ? WHERE flow_uuid = ?
ORDER BY id DESC ORDER BY id DESC
LIMIT 1 LIMIT 1
""", (flow_uuid,)) """,
(flow_uuid,),
)
row = cursor.fetchone() row = cursor.fetchone()
if row: if row:

View File

@@ -16,7 +16,8 @@ Example
import ast import ast
import inspect import inspect
import textwrap import textwrap
from typing import Any, Dict, List, Optional, Set, Union from collections import defaultdict, deque
from typing import Any, Deque, Dict, List, Optional, Set, Union
def get_possible_return_constants(function: Any) -> Optional[List[str]]: def get_possible_return_constants(function: Any) -> Optional[List[str]]:
@@ -118,7 +119,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
- Processes router paths separately - Processes router paths separately
""" """
levels: Dict[str, int] = {} levels: Dict[str, int] = {}
queue: List[str] = [] queue: Deque[str] = deque()
visited: Set[str] = set() visited: Set[str] = set()
pending_and_listeners: Dict[str, Set[str]] = {} pending_and_listeners: Dict[str, Set[str]] = {}
@@ -128,28 +129,35 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
levels[method_name] = 0 levels[method_name] = 0
queue.append(method_name) queue.append(method_name)
# Precompute listener dependencies
or_listeners = defaultdict(list)
and_listeners = defaultdict(set)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items():
if condition_type == "OR":
for method in trigger_methods:
or_listeners[method].append(listener_name)
elif condition_type == "AND":
and_listeners[listener_name] = set(trigger_methods)
# Breadth-first traversal to assign levels # Breadth-first traversal to assign levels
while queue: while queue:
current = queue.pop(0) current = queue.popleft()
current_level = levels[current] current_level = levels[current]
visited.add(current) visited.add(current)
for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): for listener_name in or_listeners[current]:
if condition_type == "OR": if listener_name not in levels or levels[listener_name] > current_level + 1:
if current in trigger_methods: levels[listener_name] = current_level + 1
if ( if listener_name not in visited:
listener_name not in levels queue.append(listener_name)
or levels[listener_name] > current_level + 1
): for listener_name, required_methods in and_listeners.items():
levels[listener_name] = current_level + 1 if current in required_methods:
if listener_name not in visited:
queue.append(listener_name)
elif condition_type == "AND":
if listener_name not in pending_and_listeners: if listener_name not in pending_and_listeners:
pending_and_listeners[listener_name] = set() pending_and_listeners[listener_name] = set()
if current in trigger_methods: pending_and_listeners[listener_name].add(current)
pending_and_listeners[listener_name].add(current)
if set(trigger_methods) == pending_and_listeners[listener_name]: if required_methods == pending_and_listeners[listener_name]:
if ( if (
listener_name not in levels listener_name not in levels
or levels[listener_name] > current_level + 1 or levels[listener_name] > current_level + 1
@@ -159,22 +167,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]:
queue.append(listener_name) queue.append(listener_name)
# Handle router connections # Handle router connections
if current in flow._routers: process_router_paths(flow, current, current_level, levels, queue)
router_method_name = current
paths = flow._router_paths.get(router_method_name, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
if listener_name not in visited:
queue.append(listener_name)
return levels return levels
@@ -227,10 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]:
def dfs_ancestors( def dfs_ancestors(
node: str, node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any
ancestors: Dict[str, Set[str]],
visited: Set[str],
flow: Any
) -> None: ) -> None:
""" """
Perform depth-first search to build ancestor relationships. Perform depth-first search to build ancestor relationships.
@@ -274,7 +264,9 @@ def dfs_ancestors(
dfs_ancestors(listener_name, ancestors, visited, flow) dfs_ancestors(listener_name, ancestors, visited, flow)
def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool: def is_ancestor(
node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]
) -> bool:
""" """
Check if one node is an ancestor of another. Check if one node is an ancestor of another.
@@ -339,7 +331,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]:
return parent_children return parent_children
def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int: def get_child_index(
parent: str, child: str, parent_children: Dict[str, List[str]]
) -> int:
""" """
Get the index of a child node in its parent's sorted children list. Get the index of a child node in its parent's sorted children list.
@@ -360,3 +354,23 @@ def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str
children = parent_children.get(parent, []) children = parent_children.get(parent, [])
children.sort() children.sort()
return children.index(child) return children.index(child)
def process_router_paths(flow, current, current_level, levels, queue):
"""
Handle the router connections for the current node.
"""
if current in flow._routers:
paths = flow._router_paths.get(current, [])
for path in paths:
for listener_name, (
condition_type,
trigger_methods,
) in flow._listeners.items():
if path in trigger_methods:
if (
listener_name not in levels
or levels[listener_name] > current_level + 1
):
levels[listener_name] = current_level + 1
queue.append(listener_name)

View File

@@ -64,6 +64,7 @@ LLM_CONTEXT_WINDOW_SIZES = {
"gpt-4-turbo": 128000, "gpt-4-turbo": 128000,
"o1-preview": 128000, "o1-preview": 128000,
"o1-mini": 128000, "o1-mini": 128000,
"o3-mini": 200000, # Based on official o3-mini specifications
# gemini # gemini
"gemini-2.0-flash": 1048576, "gemini-2.0-flash": 1048576,
"gemini-1.5-pro": 2097152, "gemini-1.5-pro": 2097152,
@@ -485,10 +486,23 @@ class LLM:
""" """
Returns the context window size, using 75% of the maximum to avoid Returns the context window size, using 75% of the maximum to avoid
cutting off messages mid-thread. cutting off messages mid-thread.
Raises:
ValueError: If a model's context window size is outside valid bounds (1024-2097152)
""" """
if self.context_window_size != 0: if self.context_window_size != 0:
return self.context_window_size return self.context_window_size
MIN_CONTEXT = 1024
MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro
# Validate all context window sizes
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
if value < MIN_CONTEXT or value > MAX_CONTEXT:
raise ValueError(
f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}"
)
self.context_window_size = int( self.context_window_size = int(
DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO
) )

View File

@@ -44,6 +44,7 @@ def create_llm(
# Extract attributes with explicit types # Extract attributes with explicit types
model = ( model = (
getattr(llm_value, "model_name", None) getattr(llm_value, "model_name", None)
or getattr(llm_value, "model", None)
or getattr(llm_value, "deployment_name", None) or getattr(llm_value, "deployment_name", None)
or str(llm_value) or str(llm_value)
) )

View File

@@ -30,8 +30,14 @@ class TokenCalcHandler(CustomLogger):
if hasattr(usage, "prompt_tokens"): if hasattr(usage, "prompt_tokens"):
self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens) self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens)
if hasattr(usage, "completion_tokens"): if hasattr(usage, "completion_tokens"):
self.token_cost_process.sum_completion_tokens(usage.completion_tokens) self.token_cost_process.sum_completion_tokens(
if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details: usage.completion_tokens
)
if (
hasattr(usage, "prompt_tokens_details")
and usage.prompt_tokens_details
and usage.prompt_tokens_details.cached_tokens
):
self.token_cost_process.sum_cached_prompt_tokens( self.token_cost_process.sum_cached_prompt_tokens(
usage.prompt_tokens_details.cached_tokens usage.prompt_tokens_details.cached_tokens
) )

View File

@@ -654,3 +654,104 @@ def test_flow_plotting():
assert isinstance(received_events[0], FlowPlotEvent) assert isinstance(received_events[0], FlowPlotEvent)
assert received_events[0].flow_name == "StatelessFlow" assert received_events[0].flow_name == "StatelessFlow"
assert isinstance(received_events[0].timestamp, datetime) assert isinstance(received_events[0].timestamp, datetime)
def test_multiple_routers_from_same_trigger():
"""Test that multiple routers triggered by the same method all activate their listeners."""
execution_order = []
class MultiRouterFlow(Flow):
def __init__(self):
super().__init__()
# Set diagnosed conditions to trigger all routers
self.state["diagnosed_conditions"] = "DHA" # Contains D, H, and A
@start()
def scan_medical(self):
execution_order.append("scan_medical")
return "scan_complete"
@router(scan_medical)
def diagnose_conditions(self):
execution_order.append("diagnose_conditions")
return "diagnosis_complete"
@router(diagnose_conditions)
def diabetes_router(self):
execution_order.append("diabetes_router")
if "D" in self.state["diagnosed_conditions"]:
return "diabetes"
return None
@listen("diabetes")
def diabetes_analysis(self):
execution_order.append("diabetes_analysis")
return "diabetes_analysis_complete"
@router(diagnose_conditions)
def hypertension_router(self):
execution_order.append("hypertension_router")
if "H" in self.state["diagnosed_conditions"]:
return "hypertension"
return None
@listen("hypertension")
def hypertension_analysis(self):
execution_order.append("hypertension_analysis")
return "hypertension_analysis_complete"
@router(diagnose_conditions)
def anemia_router(self):
execution_order.append("anemia_router")
if "A" in self.state["diagnosed_conditions"]:
return "anemia"
return None
@listen("anemia")
def anemia_analysis(self):
execution_order.append("anemia_analysis")
return "anemia_analysis_complete"
flow = MultiRouterFlow()
flow.kickoff()
# Verify all methods were called
assert "scan_medical" in execution_order
assert "diagnose_conditions" in execution_order
# Verify all routers were called
assert "diabetes_router" in execution_order
assert "hypertension_router" in execution_order
assert "anemia_router" in execution_order
# Verify all listeners were called - this is the key test for the fix
assert "diabetes_analysis" in execution_order
assert "hypertension_analysis" in execution_order
assert "anemia_analysis" in execution_order
# Verify execution order constraints
assert execution_order.index("diagnose_conditions") > execution_order.index(
"scan_medical"
)
# All routers should execute after diagnose_conditions
assert execution_order.index("diabetes_router") > execution_order.index(
"diagnose_conditions"
)
assert execution_order.index("hypertension_router") > execution_order.index(
"diagnose_conditions"
)
assert execution_order.index("anemia_router") > execution_order.index(
"diagnose_conditions"
)
# All analyses should execute after their respective routers
assert execution_order.index("diabetes_analysis") > execution_order.index(
"diabetes_router"
)
assert execution_order.index("hypertension_analysis") > execution_order.index(
"hypertension_router"
)
assert execution_order.index("anemia_analysis") > execution_order.index(
"anemia_router"
)

View File

@@ -6,7 +6,7 @@ import pytest
from pydantic import BaseModel from pydantic import BaseModel
from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess
from crewai.llm import LLM from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM
from crewai.utilities.events import crewai_event_bus from crewai.utilities.events import crewai_event_bus
from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
from crewai.utilities.token_counter_callback import TokenCalcHandler from crewai.utilities.token_counter_callback import TokenCalcHandler
@@ -285,6 +285,23 @@ def test_o3_mini_reasoning_effort_medium():
assert isinstance(result, str) assert isinstance(result, str)
assert "Paris" in result assert "Paris" in result
def test_context_window_validation():
"""Test that context window validation works correctly."""
# Test valid window size
llm = LLM(model="o3-mini")
assert llm.get_context_window_size() == int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
# Test invalid window size
with pytest.raises(ValueError) as excinfo:
with patch.dict(
"crewai.llm.LLM_CONTEXT_WINDOW_SIZES",
{"test-model": 500}, # Below minimum
clear=True,
):
llm = LLM(model="test-model")
llm.get_context_window_size()
assert "must be between 1024 and 2097152" in str(excinfo.value)
@pytest.mark.vcr(filter_headers=["authorization"]) @pytest.mark.vcr(filter_headers=["authorization"])
@pytest.fixture @pytest.fixture