diff --git a/src/crewai/agent.py b/src/crewai/agent.py index f07408133..cfebc18e5 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -114,7 +114,6 @@ class Agent(BaseAgent): @model_validator(mode="after") def post_init_setup(self): - self._set_knowledge() self.agent_ops_agent_name = self.role self.llm = create_llm(self.llm) @@ -134,8 +133,11 @@ class Agent(BaseAgent): self.cache_handler = CacheHandler() self.set_cache_handler(self.cache_handler) - def _set_knowledge(self): + def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): try: + if self.embedder is None and crew_embedder: + self.embedder = crew_embedder + if self.knowledge_sources: full_pattern = re.compile(r"[^a-zA-Z0-9\-_\r\n]|(\.\.)") knowledge_agent_name = f"{re.sub(full_pattern, '_', self.role)}" diff --git a/src/crewai/agents/agent_builder/base_agent.py b/src/crewai/agents/agent_builder/base_agent.py index 64110c2ae..f39fafb99 100644 --- a/src/crewai/agents/agent_builder/base_agent.py +++ b/src/crewai/agents/agent_builder/base_agent.py @@ -351,3 +351,6 @@ class BaseAgent(ABC, BaseModel): if not self._rpm_controller: self._rpm_controller = rpm_controller self.create_agent_executor() + + def set_knowledge(self, crew_embedder: Optional[Dict[str, Any]] = None): + pass diff --git a/src/crewai/agents/parser.py b/src/crewai/agents/parser.py index 71444a20a..1bda4df5c 100644 --- a/src/crewai/agents/parser.py +++ b/src/crewai/agents/parser.py @@ -124,14 +124,15 @@ class CrewAgentParser: ) def _extract_thought(self, text: str) -> str: - regex = r"(.*?)(?:\n\nAction|\n\nFinal Answer)" - thought_match = re.search(regex, text, re.DOTALL) - if thought_match: - thought = thought_match.group(1).strip() - # Remove any triple backticks from the thought string - thought = thought.replace("```", "").strip() - return thought - return "" + thought_index = text.find("\n\nAction") + if thought_index == -1: + thought_index = text.find("\n\nFinal Answer") + if thought_index == -1: + return "" + thought = text[:thought_index].strip() + # Remove any triple backticks from the thought string + thought = thought.replace("```", "").strip() + return thought def _clean_action(self, text: str) -> str: """Clean action string by removing non-essential formatting characters.""" diff --git a/src/crewai/cli/constants.py b/src/crewai/cli/constants.py index b97b4f208..fec0b6384 100644 --- a/src/crewai/cli/constants.py +++ b/src/crewai/cli/constants.py @@ -216,10 +216,43 @@ MODELS = { "watsonx/ibm/granite-3-8b-instruct", ], "bedrock": [ + "bedrock/us.amazon.nova-pro-v1:0", + "bedrock/us.amazon.nova-micro-v1:0", + "bedrock/us.amazon.nova-lite-v1:0", + "bedrock/us.anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/us.anthropic.claude-3-5-haiku-20241022-v1:0", + "bedrock/us.anthropic.claude-3-5-sonnet-20241022-v2:0", + "bedrock/us.anthropic.claude-3-7-sonnet-20250219-v1:0", + "bedrock/us.anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/us.anthropic.claude-3-opus-20240229-v1:0", + "bedrock/us.anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/us.meta.llama3-2-11b-instruct-v1:0", + "bedrock/us.meta.llama3-2-3b-instruct-v1:0", + "bedrock/us.meta.llama3-2-90b-instruct-v1:0", + "bedrock/us.meta.llama3-2-1b-instruct-v1:0", + "bedrock/us.meta.llama3-1-8b-instruct-v1:0", + "bedrock/us.meta.llama3-1-70b-instruct-v1:0", + "bedrock/us.meta.llama3-3-70b-instruct-v1:0", + "bedrock/us.meta.llama3-1-405b-instruct-v1:0", + "bedrock/eu.anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/eu.anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/eu.anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/eu.meta.llama3-2-3b-instruct-v1:0", + "bedrock/eu.meta.llama3-2-1b-instruct-v1:0", + "bedrock/apac.anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/apac.anthropic.claude-3-5-sonnet-20241022-v2:0", + "bedrock/apac.anthropic.claude-3-sonnet-20240229-v1:0", + "bedrock/apac.anthropic.claude-3-haiku-20240307-v1:0", + "bedrock/amazon.nova-pro-v1:0", + "bedrock/amazon.nova-micro-v1:0", + "bedrock/amazon.nova-lite-v1:0", "bedrock/anthropic.claude-3-5-sonnet-20240620-v1:0", + "bedrock/anthropic.claude-3-5-haiku-20241022-v1:0", + "bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0", + "bedrock/anthropic.claude-3-7-sonnet-20250219-v1:0", "bedrock/anthropic.claude-3-sonnet-20240229-v1:0", - "bedrock/anthropic.claude-3-haiku-20240307-v1:0", "bedrock/anthropic.claude-3-opus-20240229-v1:0", + "bedrock/anthropic.claude-3-haiku-20240307-v1:0", "bedrock/anthropic.claude-v2:1", "bedrock/anthropic.claude-v2", "bedrock/anthropic.claude-instant-v1", @@ -234,8 +267,6 @@ MODELS = { "bedrock/ai21.j2-mid-v1", "bedrock/ai21.j2-ultra-v1", "bedrock/ai21.jamba-instruct-v1:0", - "bedrock/meta.llama2-13b-chat-v1", - "bedrock/meta.llama2-70b-chat-v1", "bedrock/mistral.mistral-7b-instruct-v0:2", "bedrock/mistral.mixtral-8x7b-instruct-v0:1", ], diff --git a/src/crewai/crew.py b/src/crewai/crew.py index cf627700e..9cecfed3a 100644 --- a/src/crewai/crew.py +++ b/src/crewai/crew.py @@ -600,6 +600,7 @@ class Crew(BaseModel): agent.i18n = i18n # type: ignore[attr-defined] # Argument 1 to "_interpolate_inputs" of "Crew" has incompatible type "dict[str, Any] | None"; expected "dict[str, Any]" agent.crew = self # type: ignore[attr-defined] + agent.set_knowledge(crew_embedder=self.embedder) # TODO: Create an AgentFunctionCalling protocol for future refactoring if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm" agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm" diff --git a/src/crewai/flow/flow.py b/src/crewai/flow/flow.py index 7a8b88ba0..3b6e81293 100644 --- a/src/crewai/flow/flow.py +++ b/src/crewai/flow/flow.py @@ -894,35 +894,45 @@ class Flow(Generic[T], metaclass=FlowMeta): Notes ----- - Routers are executed sequentially to maintain flow control - - Each router's result becomes the new trigger_method + - Each router's result becomes a new trigger_method - Normal listeners are executed in parallel for efficiency - Listeners can receive the trigger method's result as a parameter """ # First, handle routers repeatedly until no router triggers anymore + router_results = [] + current_trigger = trigger_method + while True: routers_triggered = self._find_triggered_methods( - trigger_method, router_only=True + current_trigger, router_only=True ) if not routers_triggered: break + for router_name in routers_triggered: await self._execute_single_listener(router_name, result) # After executing router, the router's result is the path - # The last router executed sets the trigger_method - # The router result is the last element in self._method_outputs - trigger_method = self._method_outputs[-1] + router_result = self._method_outputs[-1] + if router_result: # Only add non-None results + router_results.append(router_result) + current_trigger = ( + router_result # Update for next iteration of router chain + ) - # Now that no more routers are triggered by current trigger_method, - # execute normal listeners - listeners_triggered = self._find_triggered_methods( - trigger_method, router_only=False - ) - if listeners_triggered: - tasks = [ - self._execute_single_listener(listener_name, result) - for listener_name in listeners_triggered - ] - await asyncio.gather(*tasks) + # Now execute normal listeners for all router results and the original trigger + all_triggers = [trigger_method] + router_results + + for current_trigger in all_triggers: + if current_trigger: # Skip None results + listeners_triggered = self._find_triggered_methods( + current_trigger, router_only=False + ) + if listeners_triggered: + tasks = [ + self._execute_single_listener(listener_name, result) + for listener_name in listeners_triggered + ] + await asyncio.gather(*tasks) def _find_triggered_methods( self, trigger_method: str, router_only: bool diff --git a/src/crewai/flow/persistence/sqlite.py b/src/crewai/flow/persistence/sqlite.py index 7a6f134fa..21e906afd 100644 --- a/src/crewai/flow/persistence/sqlite.py +++ b/src/crewai/flow/persistence/sqlite.py @@ -4,7 +4,7 @@ SQLite-based implementation of flow state persistence. import json import sqlite3 -from datetime import datetime +from datetime import datetime, timezone from pathlib import Path from typing import Any, Dict, Optional, Union @@ -34,6 +34,7 @@ class SQLiteFlowPersistence(FlowPersistence): ValueError: If db_path is invalid """ from crewai.utilities.paths import db_storage_path + # Get path from argument or default location path = db_path or str(Path(db_storage_path()) / "flow_states.db") @@ -46,7 +47,8 @@ class SQLiteFlowPersistence(FlowPersistence): def init_db(self) -> None: """Create the necessary tables if they don't exist.""" with sqlite3.connect(self.db_path) as conn: - conn.execute(""" + conn.execute( + """ CREATE TABLE IF NOT EXISTS flow_states ( id INTEGER PRIMARY KEY AUTOINCREMENT, flow_uuid TEXT NOT NULL, @@ -54,12 +56,15 @@ class SQLiteFlowPersistence(FlowPersistence): timestamp DATETIME NOT NULL, state_json TEXT NOT NULL ) - """) + """ + ) # Add index for faster UUID lookups - conn.execute(""" + conn.execute( + """ CREATE INDEX IF NOT EXISTS idx_flow_states_uuid ON flow_states(flow_uuid) - """) + """ + ) def save_state( self, @@ -85,19 +90,22 @@ class SQLiteFlowPersistence(FlowPersistence): ) with sqlite3.connect(self.db_path) as conn: - conn.execute(""" + conn.execute( + """ INSERT INTO flow_states ( flow_uuid, method_name, timestamp, state_json ) VALUES (?, ?, ?, ?) - """, ( - flow_uuid, - method_name, - datetime.utcnow().isoformat(), - json.dumps(state_dict), - )) + """, + ( + flow_uuid, + method_name, + datetime.now(timezone.utc).isoformat(), + json.dumps(state_dict), + ), + ) def load_state(self, flow_uuid: str) -> Optional[Dict[str, Any]]: """Load the most recent state for a given flow UUID. @@ -109,13 +117,16 @@ class SQLiteFlowPersistence(FlowPersistence): The most recent state as a dictionary, or None if no state exists """ with sqlite3.connect(self.db_path) as conn: - cursor = conn.execute(""" + cursor = conn.execute( + """ SELECT state_json FROM flow_states WHERE flow_uuid = ? ORDER BY id DESC LIMIT 1 - """, (flow_uuid,)) + """, + (flow_uuid,), + ) row = cursor.fetchone() if row: diff --git a/src/crewai/flow/utils.py b/src/crewai/flow/utils.py index c0686222f..81f3c1041 100644 --- a/src/crewai/flow/utils.py +++ b/src/crewai/flow/utils.py @@ -16,7 +16,8 @@ Example import ast import inspect import textwrap -from typing import Any, Dict, List, Optional, Set, Union +from collections import defaultdict, deque +from typing import Any, Deque, Dict, List, Optional, Set, Union def get_possible_return_constants(function: Any) -> Optional[List[str]]: @@ -118,7 +119,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: - Processes router paths separately """ levels: Dict[str, int] = {} - queue: List[str] = [] + queue: Deque[str] = deque() visited: Set[str] = set() pending_and_listeners: Dict[str, Set[str]] = {} @@ -128,28 +129,35 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: levels[method_name] = 0 queue.append(method_name) + # Precompute listener dependencies + or_listeners = defaultdict(list) + and_listeners = defaultdict(set) + for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): + if condition_type == "OR": + for method in trigger_methods: + or_listeners[method].append(listener_name) + elif condition_type == "AND": + and_listeners[listener_name] = set(trigger_methods) + # Breadth-first traversal to assign levels while queue: - current = queue.pop(0) + current = queue.popleft() current_level = levels[current] visited.add(current) - for listener_name, (condition_type, trigger_methods) in flow._listeners.items(): - if condition_type == "OR": - if current in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) - elif condition_type == "AND": + for listener_name in or_listeners[current]: + if listener_name not in levels or levels[listener_name] > current_level + 1: + levels[listener_name] = current_level + 1 + if listener_name not in visited: + queue.append(listener_name) + + for listener_name, required_methods in and_listeners.items(): + if current in required_methods: if listener_name not in pending_and_listeners: pending_and_listeners[listener_name] = set() - if current in trigger_methods: - pending_and_listeners[listener_name].add(current) - if set(trigger_methods) == pending_and_listeners[listener_name]: + pending_and_listeners[listener_name].add(current) + + if required_methods == pending_and_listeners[listener_name]: if ( listener_name not in levels or levels[listener_name] > current_level + 1 @@ -159,22 +167,7 @@ def calculate_node_levels(flow: Any) -> Dict[str, int]: queue.append(listener_name) # Handle router connections - if current in flow._routers: - router_method_name = current - paths = flow._router_paths.get(router_method_name, []) - for path in paths: - for listener_name, ( - condition_type, - trigger_methods, - ) in flow._listeners.items(): - if path in trigger_methods: - if ( - listener_name not in levels - or levels[listener_name] > current_level + 1 - ): - levels[listener_name] = current_level + 1 - if listener_name not in visited: - queue.append(listener_name) + process_router_paths(flow, current, current_level, levels, queue) return levels @@ -227,10 +220,7 @@ def build_ancestor_dict(flow: Any) -> Dict[str, Set[str]]: def dfs_ancestors( - node: str, - ancestors: Dict[str, Set[str]], - visited: Set[str], - flow: Any + node: str, ancestors: Dict[str, Set[str]], visited: Set[str], flow: Any ) -> None: """ Perform depth-first search to build ancestor relationships. @@ -274,7 +264,9 @@ def dfs_ancestors( dfs_ancestors(listener_name, ancestors, visited, flow) -def is_ancestor(node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]]) -> bool: +def is_ancestor( + node: str, ancestor_candidate: str, ancestors: Dict[str, Set[str]] +) -> bool: """ Check if one node is an ancestor of another. @@ -339,7 +331,9 @@ def build_parent_children_dict(flow: Any) -> Dict[str, List[str]]: return parent_children -def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str]]) -> int: +def get_child_index( + parent: str, child: str, parent_children: Dict[str, List[str]] +) -> int: """ Get the index of a child node in its parent's sorted children list. @@ -360,3 +354,23 @@ def get_child_index(parent: str, child: str, parent_children: Dict[str, List[str children = parent_children.get(parent, []) children.sort() return children.index(child) + + +def process_router_paths(flow, current, current_level, levels, queue): + """ + Handle the router connections for the current node. + """ + if current in flow._routers: + paths = flow._router_paths.get(current, []) + for path in paths: + for listener_name, ( + condition_type, + trigger_methods, + ) in flow._listeners.items(): + if path in trigger_methods: + if ( + listener_name not in levels + or levels[listener_name] > current_level + 1 + ): + levels[listener_name] = current_level + 1 + queue.append(listener_name) diff --git a/src/crewai/llm.py b/src/crewai/llm.py index 2eefa8934..0c8a46214 100644 --- a/src/crewai/llm.py +++ b/src/crewai/llm.py @@ -64,6 +64,7 @@ LLM_CONTEXT_WINDOW_SIZES = { "gpt-4-turbo": 128000, "o1-preview": 128000, "o1-mini": 128000, + "o3-mini": 200000, # Based on official o3-mini specifications # gemini "gemini-2.0-flash": 1048576, "gemini-1.5-pro": 2097152, @@ -485,10 +486,23 @@ class LLM: """ Returns the context window size, using 75% of the maximum to avoid cutting off messages mid-thread. + + Raises: + ValueError: If a model's context window size is outside valid bounds (1024-2097152) """ if self.context_window_size != 0: return self.context_window_size + MIN_CONTEXT = 1024 + MAX_CONTEXT = 2097152 # Current max from gemini-1.5-pro + + # Validate all context window sizes + for key, value in LLM_CONTEXT_WINDOW_SIZES.items(): + if value < MIN_CONTEXT or value > MAX_CONTEXT: + raise ValueError( + f"Context window for {key} must be between {MIN_CONTEXT} and {MAX_CONTEXT}" + ) + self.context_window_size = int( DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO ) diff --git a/src/crewai/utilities/llm_utils.py b/src/crewai/utilities/llm_utils.py index c774a71fb..4d34d789c 100644 --- a/src/crewai/utilities/llm_utils.py +++ b/src/crewai/utilities/llm_utils.py @@ -44,6 +44,7 @@ def create_llm( # Extract attributes with explicit types model = ( getattr(llm_value, "model_name", None) + or getattr(llm_value, "model", None) or getattr(llm_value, "deployment_name", None) or str(llm_value) ) diff --git a/src/crewai/utilities/token_counter_callback.py b/src/crewai/utilities/token_counter_callback.py index e612fcae4..7037ad5c4 100644 --- a/src/crewai/utilities/token_counter_callback.py +++ b/src/crewai/utilities/token_counter_callback.py @@ -30,8 +30,14 @@ class TokenCalcHandler(CustomLogger): if hasattr(usage, "prompt_tokens"): self.token_cost_process.sum_prompt_tokens(usage.prompt_tokens) if hasattr(usage, "completion_tokens"): - self.token_cost_process.sum_completion_tokens(usage.completion_tokens) - if hasattr(usage, "prompt_tokens_details") and usage.prompt_tokens_details: + self.token_cost_process.sum_completion_tokens( + usage.completion_tokens + ) + if ( + hasattr(usage, "prompt_tokens_details") + and usage.prompt_tokens_details + and usage.prompt_tokens_details.cached_tokens + ): self.token_cost_process.sum_cached_prompt_tokens( usage.prompt_tokens_details.cached_tokens ) diff --git a/tests/flow_test.py b/tests/flow_test.py index b2edcfa5a..c2640fffb 100644 --- a/tests/flow_test.py +++ b/tests/flow_test.py @@ -654,3 +654,104 @@ def test_flow_plotting(): assert isinstance(received_events[0], FlowPlotEvent) assert received_events[0].flow_name == "StatelessFlow" assert isinstance(received_events[0].timestamp, datetime) + + +def test_multiple_routers_from_same_trigger(): + """Test that multiple routers triggered by the same method all activate their listeners.""" + execution_order = [] + + class MultiRouterFlow(Flow): + def __init__(self): + super().__init__() + # Set diagnosed conditions to trigger all routers + self.state["diagnosed_conditions"] = "DHA" # Contains D, H, and A + + @start() + def scan_medical(self): + execution_order.append("scan_medical") + return "scan_complete" + + @router(scan_medical) + def diagnose_conditions(self): + execution_order.append("diagnose_conditions") + return "diagnosis_complete" + + @router(diagnose_conditions) + def diabetes_router(self): + execution_order.append("diabetes_router") + if "D" in self.state["diagnosed_conditions"]: + return "diabetes" + return None + + @listen("diabetes") + def diabetes_analysis(self): + execution_order.append("diabetes_analysis") + return "diabetes_analysis_complete" + + @router(diagnose_conditions) + def hypertension_router(self): + execution_order.append("hypertension_router") + if "H" in self.state["diagnosed_conditions"]: + return "hypertension" + return None + + @listen("hypertension") + def hypertension_analysis(self): + execution_order.append("hypertension_analysis") + return "hypertension_analysis_complete" + + @router(diagnose_conditions) + def anemia_router(self): + execution_order.append("anemia_router") + if "A" in self.state["diagnosed_conditions"]: + return "anemia" + return None + + @listen("anemia") + def anemia_analysis(self): + execution_order.append("anemia_analysis") + return "anemia_analysis_complete" + + flow = MultiRouterFlow() + flow.kickoff() + + # Verify all methods were called + assert "scan_medical" in execution_order + assert "diagnose_conditions" in execution_order + + # Verify all routers were called + assert "diabetes_router" in execution_order + assert "hypertension_router" in execution_order + assert "anemia_router" in execution_order + + # Verify all listeners were called - this is the key test for the fix + assert "diabetes_analysis" in execution_order + assert "hypertension_analysis" in execution_order + assert "anemia_analysis" in execution_order + + # Verify execution order constraints + assert execution_order.index("diagnose_conditions") > execution_order.index( + "scan_medical" + ) + + # All routers should execute after diagnose_conditions + assert execution_order.index("diabetes_router") > execution_order.index( + "diagnose_conditions" + ) + assert execution_order.index("hypertension_router") > execution_order.index( + "diagnose_conditions" + ) + assert execution_order.index("anemia_router") > execution_order.index( + "diagnose_conditions" + ) + + # All analyses should execute after their respective routers + assert execution_order.index("diabetes_analysis") > execution_order.index( + "diabetes_router" + ) + assert execution_order.index("hypertension_analysis") > execution_order.index( + "hypertension_router" + ) + assert execution_order.index("anemia_analysis") > execution_order.index( + "anemia_router" + ) diff --git a/tests/llm_test.py b/tests/llm_test.py index 00bb69aa5..61aa1aced 100644 --- a/tests/llm_test.py +++ b/tests/llm_test.py @@ -6,7 +6,7 @@ import pytest from pydantic import BaseModel from crewai.agents.agent_builder.utilities.base_token_process import TokenProcess -from crewai.llm import LLM +from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO, LLM from crewai.utilities.events import crewai_event_bus from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent from crewai.utilities.token_counter_callback import TokenCalcHandler @@ -285,6 +285,23 @@ def test_o3_mini_reasoning_effort_medium(): assert isinstance(result, str) assert "Paris" in result +def test_context_window_validation(): + """Test that context window validation works correctly.""" + # Test valid window size + llm = LLM(model="o3-mini") + assert llm.get_context_window_size() == int(200000 * CONTEXT_WINDOW_USAGE_RATIO) + + # Test invalid window size + with pytest.raises(ValueError) as excinfo: + with patch.dict( + "crewai.llm.LLM_CONTEXT_WINDOW_SIZES", + {"test-model": 500}, # Below minimum + clear=True, + ): + llm = LLM(model="test-model") + llm.get_context_window_size() + assert "must be between 1024 and 2097152" in str(excinfo.value) + @pytest.mark.vcr(filter_headers=["authorization"]) @pytest.fixture