mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-07-01 13:18:10 +00:00
Enhance Bedrock request handling and update test cassettes
- Introduced a custom matcher for Bedrock requests to normalize regional endpoints, ensuring consistent behavior across AWS regions.
- Updated the VCR configuration to utilize the new matcher.
- Adjusted the test cassette to replace the original Bedrock endpoint with a placeholder for improved testing consistency.
- Modified response body and headers in the test cassette to reflect updated expected values.
This commit is contained in:
@@ -186,8 +186,6 @@ class Telemetry:
|
||||
|
||||
self._safe_telemetry_procedure(_operation)
|
||||
|
||||
# --- CLI-facing spans ---------------------------------------------------
|
||||
|
||||
def deploy_signup_error_span(self) -> None:
|
||||
"""Records when an error occurs during the deployment signup process."""
|
||||
|
||||
|
||||
@@ -1129,7 +1129,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Check if model supports native function calling
|
||||
use_native_tools = (
|
||||
hasattr(self.llm, "supports_function_calling")
|
||||
and callable(getattr(self.llm, "supports_function_calling", None))
|
||||
@@ -1140,7 +1139,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
if use_native_tools:
|
||||
return await self._ainvoke_loop_native_tools()
|
||||
|
||||
# Fall back to ReAct text-based pattern
|
||||
return await self._ainvoke_loop_react()
|
||||
|
||||
async def _ainvoke_loop_react(self) -> AgentFinish:
|
||||
@@ -1289,7 +1287,6 @@ class CrewAgentExecutor(BaseAgentExecutor):
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
# Convert tools to OpenAI schema format
|
||||
if not self.original_tools:
|
||||
return await self._ainvoke_loop_native_no_tools()
|
||||
|
||||
|
||||
@@ -211,8 +211,6 @@ class SyncHumanInputProvider(HumanInputProvider):
|
||||
formatted_answer, feedback, context
|
||||
)
|
||||
|
||||
# ── Sync helpers ──────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _handle_training_feedback(
|
||||
initial_answer: AgentFinish,
|
||||
@@ -265,8 +263,6 @@ class SyncHumanInputProvider(HumanInputProvider):
|
||||
|
||||
return answer
|
||||
|
||||
# ── Async helpers ─────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
async def _handle_training_feedback_async(
|
||||
initial_answer: AgentFinish,
|
||||
@@ -319,8 +315,6 @@ class SyncHumanInputProvider(HumanInputProvider):
|
||||
|
||||
return answer
|
||||
|
||||
# ── I/O ───────────────────────────────────────────────────────────
|
||||
|
||||
@staticmethod
|
||||
def _prompt_input(crew: Crew | None) -> str:
|
||||
"""Show rich panel and prompt for input.
|
||||
|
||||
@@ -554,7 +554,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
stack.append((self._kickoff_event_id, "crew_kickoff_started"))
|
||||
restore_event_scope(tuple(stack))
|
||||
|
||||
# Restore last_event_id and emission counter from the record
|
||||
last_event_id: str | None = None
|
||||
max_seq = 0
|
||||
for node in state.event_record.nodes.values():
|
||||
@@ -613,7 +612,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._cache_handler = CacheHandler()
|
||||
event_listener = EventListener()
|
||||
|
||||
# Determine and set tracing state once for this execution
|
||||
tracing_enabled = should_enable_tracing(override=self.tracing)
|
||||
set_tracing_enabled(tracing_enabled)
|
||||
|
||||
@@ -641,7 +639,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""
|
||||
from crewai.memory.utils import sanitize_scope_name
|
||||
|
||||
# Compute sanitized crew name for root_scope
|
||||
crew_name = sanitize_scope_name(self.name or "crew")
|
||||
crew_root_scope = f"/crew/{crew_name}"
|
||||
|
||||
@@ -747,7 +744,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
"""Validates that the crew ends with at most one asynchronous task."""
|
||||
final_async_task_count = 0
|
||||
|
||||
# Traverse tasks backward
|
||||
for task in reversed(self.tasks):
|
||||
if task.async_execution:
|
||||
final_async_task_count += 1
|
||||
@@ -837,7 +833,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if isinstance(task.context, list):
|
||||
for context_task in task.context:
|
||||
if id(context_task) not in task_indices:
|
||||
continue # Skip context tasks not in the main tasks list
|
||||
continue
|
||||
if task_indices[id(context_task)] > task_indices[id(task)]:
|
||||
raise ValueError(
|
||||
f"Task '{task.description}' has a context dependency "
|
||||
@@ -1040,7 +1036,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
# Ensure all background memory saves complete before returning
|
||||
if self._memory is not None and hasattr(self._memory, "drain_writes"):
|
||||
self._memory.drain_writes()
|
||||
clear_files(self.id)
|
||||
@@ -1592,7 +1587,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
|
||||
) -> list[BaseTool]:
|
||||
# Add delegation tools if agent allows delegation
|
||||
if hasattr(agent, "allow_delegation") and getattr(
|
||||
agent, "allow_delegation", False
|
||||
):
|
||||
@@ -1607,7 +1601,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
elif agent:
|
||||
tools = self._add_delegation_tools(task, tools)
|
||||
|
||||
# Add code execution tools if agent allows code execution
|
||||
if hasattr(agent, "allow_code_execution") and getattr(
|
||||
agent, "allow_code_execution", False
|
||||
):
|
||||
@@ -1627,7 +1620,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if agent and (hasattr(agent, "mcps") and getattr(agent, "mcps", None)):
|
||||
tools = self._add_mcp_tools(task, tools)
|
||||
|
||||
# Add memory tools if memory is available (agent or crew level)
|
||||
resolved_memory = getattr(agent, "memory", None) or self._memory
|
||||
if resolved_memory is not None:
|
||||
tools = self._add_memory_tools(tools, resolved_memory)
|
||||
@@ -1651,7 +1643,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def is_auto_injected(content_type: str) -> bool:
|
||||
return any(content_type.startswith(t) for t in supported_types)
|
||||
|
||||
# Only add read_file tool if there are files that need it
|
||||
files_needing_tool = {
|
||||
name: f
|
||||
for name, f in files.items()
|
||||
@@ -1676,17 +1667,14 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not new_tools:
|
||||
return existing_tools
|
||||
|
||||
# Create mapping of tool names to new tools
|
||||
new_tool_map = {sanitize_tool_name(tool.name): tool for tool in new_tools}
|
||||
|
||||
# Remove any existing tools that will be replaced
|
||||
tools = [
|
||||
tool
|
||||
for tool in existing_tools
|
||||
if sanitize_tool_name(tool.name) not in new_tool_map
|
||||
]
|
||||
|
||||
# Add all new tools
|
||||
tools.extend(new_tools)
|
||||
|
||||
return tools
|
||||
@@ -1699,7 +1687,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(task_agent, "get_delegation_tools"):
|
||||
delegation_tools = task_agent.get_delegation_tools(agents)
|
||||
# Cast delegation_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, delegation_tools)
|
||||
return tools
|
||||
|
||||
@@ -1739,7 +1726,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
) -> list[BaseTool]:
|
||||
if hasattr(agent, "get_code_execution_tools"):
|
||||
code_tools = agent.get_code_execution_tools()
|
||||
# Cast code_tools to the expected type for _merge_tools
|
||||
return self._merge_tools(tools, cast(list[BaseTool], code_tools))
|
||||
return tools
|
||||
|
||||
@@ -1844,7 +1830,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if not task_outputs:
|
||||
raise ValueError("No task outputs available to create crew output.")
|
||||
|
||||
# Filter out empty outputs and get the last valid one as the main output
|
||||
valid_outputs = [t for t in task_outputs if t.raw]
|
||||
if not valid_outputs:
|
||||
raise ValueError("No valid task outputs available to create crew output.")
|
||||
@@ -1972,13 +1957,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
for task in self.tasks:
|
||||
# description and expected_output might contain e.g. {topic}, {user_name}
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
# Scan agents for inputs
|
||||
for agent in self.agents:
|
||||
# role, goal, backstory might have placeholders like {role_detail}, etc.
|
||||
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
|
||||
@@ -2083,7 +2066,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
total_usage_metrics.add_usage_metrics(llm_usage)
|
||||
else:
|
||||
# fallback litellm
|
||||
if hasattr(agent, "_token_process"):
|
||||
token_sum = agent._token_process.get_summary()
|
||||
total_usage_metrics.add_usage_metrics(token_sum)
|
||||
@@ -2111,7 +2093,6 @@ class Crew(FlowTrackable, BaseModel):
|
||||
Uses concurrent.futures for concurrent execution.
|
||||
"""
|
||||
try:
|
||||
# Create LLM instance and ensure it's of type LLM for CrewEvaluator
|
||||
llm_instance = create_llm(eval_llm)
|
||||
if not llm_instance:
|
||||
raise ValueError("Failed to create LLM instance.")
|
||||
@@ -2270,13 +2251,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
def knowledge_reset(memory: Any) -> Any:
|
||||
return self.reset_knowledge(memory)
|
||||
|
||||
# Get knowledge for agents
|
||||
agent_knowledges = [
|
||||
getattr(agent, "knowledge", None)
|
||||
for agent in self.agents
|
||||
if getattr(agent, "knowledge", None) is not None
|
||||
]
|
||||
# Get knowledge for crew and agents
|
||||
crew_knowledge = getattr(self, "knowledge", None)
|
||||
crew_and_agent_knowledges = (
|
||||
[crew_knowledge] if crew_knowledge is not None else []
|
||||
|
||||
@@ -157,7 +157,6 @@ def prepare_task_execution(
|
||||
Raises:
|
||||
ValueError: If no agent is available for the task.
|
||||
"""
|
||||
# Handle replay skip
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
task_outputs.append(task.output)
|
||||
@@ -290,7 +289,6 @@ def prepare_kickoff(
|
||||
reset_emission_counter()
|
||||
reset_last_event_id()
|
||||
|
||||
# Normalize inputs to dict[str, Any] for internal processing
|
||||
normalized: dict[str, Any] | None = None
|
||||
if inputs is not None:
|
||||
if not isinstance(inputs, Mapping):
|
||||
@@ -331,15 +329,12 @@ def prepare_kickoff(
|
||||
crew._task_output_handler.reset()
|
||||
crew._logging_color = "bold_purple"
|
||||
|
||||
# Check for flow input files in baggage context (inherited from parent Flow)
|
||||
_flow_files = baggage.get_baggage("flow_input_files")
|
||||
flow_files: dict[str, Any] = _flow_files if isinstance(_flow_files, dict) else {}
|
||||
|
||||
if normalized is not None:
|
||||
# Extract file objects unpacked directly into inputs
|
||||
unpacked_files = _extract_files_from_inputs(normalized)
|
||||
|
||||
# Merge files: flow_files < input_files < unpacked_files (later takes precedence)
|
||||
all_files = {**flow_files, **(input_files or {}), **unpacked_files}
|
||||
if all_files:
|
||||
store_files(crew.id, all_files)
|
||||
@@ -347,7 +342,6 @@ def prepare_kickoff(
|
||||
crew._inputs = normalized
|
||||
crew._interpolate_inputs(normalized)
|
||||
else:
|
||||
# No inputs dict provided
|
||||
all_files = {**flow_files, **(input_files or {})}
|
||||
if all_files:
|
||||
store_files(crew.id, all_files)
|
||||
|
||||
@@ -144,9 +144,7 @@ if TYPE_CHECKING:
|
||||
ToolValidateInputErrorEvent,
|
||||
)
|
||||
|
||||
# Map every event class name → its module path for lazy loading
|
||||
_LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
# agent_events
|
||||
"AgentEvaluationCompletedEvent": "crewai.events.types.agent_events",
|
||||
"AgentEvaluationFailedEvent": "crewai.events.types.agent_events",
|
||||
"AgentEvaluationStartedEvent": "crewai.events.types.agent_events",
|
||||
@@ -156,7 +154,6 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"LiteAgentExecutionCompletedEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionErrorEvent": "crewai.events.types.agent_events",
|
||||
"LiteAgentExecutionStartedEvent": "crewai.events.types.agent_events",
|
||||
# checkpoint_events
|
||||
"CheckpointBaseEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointCompletedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointFailedEvent": "crewai.events.types.checkpoint_events",
|
||||
@@ -169,7 +166,6 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"CheckpointRestoreFailedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointRestoreStartedEvent": "crewai.events.types.checkpoint_events",
|
||||
"CheckpointStartedEvent": "crewai.events.types.checkpoint_events",
|
||||
# crew_events
|
||||
"CrewKickoffCompletedEvent": "crewai.events.types.crew_events",
|
||||
"CrewKickoffFailedEvent": "crewai.events.types.crew_events",
|
||||
"CrewKickoffStartedEvent": "crewai.events.types.crew_events",
|
||||
@@ -180,7 +176,6 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"CrewTrainCompletedEvent": "crewai.events.types.crew_events",
|
||||
"CrewTrainFailedEvent": "crewai.events.types.crew_events",
|
||||
"CrewTrainStartedEvent": "crewai.events.types.crew_events",
|
||||
# flow_events
|
||||
"FlowCreatedEvent": "crewai.events.types.flow_events",
|
||||
"FlowEvent": "crewai.events.types.flow_events",
|
||||
"FlowFinishedEvent": "crewai.events.types.flow_events",
|
||||
@@ -191,25 +186,20 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"MethodExecutionFailedEvent": "crewai.events.types.flow_events",
|
||||
"MethodExecutionFinishedEvent": "crewai.events.types.flow_events",
|
||||
"MethodExecutionStartedEvent": "crewai.events.types.flow_events",
|
||||
# knowledge_events
|
||||
"KnowledgeQueryCompletedEvent": "crewai.events.types.knowledge_events",
|
||||
"KnowledgeQueryFailedEvent": "crewai.events.types.knowledge_events",
|
||||
"KnowledgeQueryStartedEvent": "crewai.events.types.knowledge_events",
|
||||
"KnowledgeRetrievalCompletedEvent": "crewai.events.types.knowledge_events",
|
||||
"KnowledgeRetrievalStartedEvent": "crewai.events.types.knowledge_events",
|
||||
"KnowledgeSearchQueryFailedEvent": "crewai.events.types.knowledge_events",
|
||||
# llm_events
|
||||
"LLMCallCompletedEvent": "crewai.events.types.llm_events",
|
||||
"LLMCallFailedEvent": "crewai.events.types.llm_events",
|
||||
"LLMCallStartedEvent": "crewai.events.types.llm_events",
|
||||
"LLMStreamChunkEvent": "crewai.events.types.llm_events",
|
||||
# llm_guardrail_events
|
||||
"LLMGuardrailCompletedEvent": "crewai.events.types.llm_guardrail_events",
|
||||
"LLMGuardrailStartedEvent": "crewai.events.types.llm_guardrail_events",
|
||||
# logging_events
|
||||
"AgentLogsExecutionEvent": "crewai.events.types.logging_events",
|
||||
"AgentLogsStartedEvent": "crewai.events.types.logging_events",
|
||||
# mcp_events
|
||||
"MCPConfigFetchFailedEvent": "crewai.events.types.mcp_events",
|
||||
"MCPConnectionCompletedEvent": "crewai.events.types.mcp_events",
|
||||
"MCPConnectionFailedEvent": "crewai.events.types.mcp_events",
|
||||
@@ -217,7 +207,6 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"MCPToolExecutionCompletedEvent": "crewai.events.types.mcp_events",
|
||||
"MCPToolExecutionFailedEvent": "crewai.events.types.mcp_events",
|
||||
"MCPToolExecutionStartedEvent": "crewai.events.types.mcp_events",
|
||||
# memory_events
|
||||
"MemoryQueryCompletedEvent": "crewai.events.types.memory_events",
|
||||
"MemoryQueryFailedEvent": "crewai.events.types.memory_events",
|
||||
"MemoryQueryStartedEvent": "crewai.events.types.memory_events",
|
||||
@@ -227,24 +216,20 @@ _LAZY_EVENT_MAPPING: dict[str, str] = {
|
||||
"MemorySaveCompletedEvent": "crewai.events.types.memory_events",
|
||||
"MemorySaveFailedEvent": "crewai.events.types.memory_events",
|
||||
"MemorySaveStartedEvent": "crewai.events.types.memory_events",
|
||||
# reasoning_events
|
||||
"AgentReasoningCompletedEvent": "crewai.events.types.reasoning_events",
|
||||
"AgentReasoningFailedEvent": "crewai.events.types.reasoning_events",
|
||||
"AgentReasoningStartedEvent": "crewai.events.types.reasoning_events",
|
||||
"ReasoningEvent": "crewai.events.types.reasoning_events",
|
||||
# skill_events
|
||||
"SkillActivatedEvent": "crewai.events.types.skill_events",
|
||||
"SkillDiscoveryCompletedEvent": "crewai.events.types.skill_events",
|
||||
"SkillDiscoveryStartedEvent": "crewai.events.types.skill_events",
|
||||
"SkillEvent": "crewai.events.types.skill_events",
|
||||
"SkillLoadFailedEvent": "crewai.events.types.skill_events",
|
||||
"SkillLoadedEvent": "crewai.events.types.skill_events",
|
||||
# task_events
|
||||
"TaskCompletedEvent": "crewai.events.types.task_events",
|
||||
"TaskEvaluationEvent": "crewai.events.types.task_events",
|
||||
"TaskFailedEvent": "crewai.events.types.task_events",
|
||||
"TaskStartedEvent": "crewai.events.types.task_events",
|
||||
# tool_usage_events
|
||||
"ToolExecutionErrorEvent": "crewai.events.types.tool_usage_events",
|
||||
"ToolSelectionErrorEvent": "crewai.events.types.tool_usage_events",
|
||||
"ToolUsageErrorEvent": "crewai.events.types.tool_usage_events",
|
||||
|
||||
@@ -149,7 +149,6 @@ class CrewAIEventsBus:
|
||||
] = {}
|
||||
self._execution_plan_cache: dict[type[BaseEvent], ExecutionPlan] = {}
|
||||
self._console = ConsoleFormatter()
|
||||
# Lazy initialization flags - executor and loop created on first emit
|
||||
self._executor_initialized = False
|
||||
self._has_pending_events = False
|
||||
self._runtime_state: RuntimeState | None = None
|
||||
@@ -551,13 +550,10 @@ class CrewAIEventsBus:
|
||||
sync_handlers = self._sync_handlers.get(event_type, frozenset())
|
||||
async_handlers = self._async_handlers.get(event_type, frozenset())
|
||||
|
||||
# Skip executor initialization if no handlers exist for this event
|
||||
if not sync_handlers and not async_handlers:
|
||||
return None
|
||||
|
||||
# Lazily initialize executor and event loop only when handlers exist
|
||||
self._ensure_executor_initialized()
|
||||
# Track that we have pending events for flush optimization
|
||||
self._has_pending_events = True
|
||||
|
||||
if has_dependencies:
|
||||
@@ -684,7 +680,6 @@ class CrewAIEventsBus:
|
||||
Returns:
|
||||
True if all handlers completed, False if timeout occurred.
|
||||
"""
|
||||
# Skip flush entirely if no events were ever emitted
|
||||
if not self._has_pending_events:
|
||||
return True
|
||||
|
||||
@@ -698,7 +693,6 @@ class CrewAIEventsBus:
|
||||
|
||||
done, not_done = wait_futures(futures_to_wait, timeout=timeout)
|
||||
|
||||
# Check for exceptions in completed futures
|
||||
errors = [
|
||||
future.exception() for future in done if future.exception() is not None
|
||||
]
|
||||
@@ -847,7 +841,6 @@ class CrewAIEventsBus:
|
||||
|
||||
with self._rwlock.w_locked():
|
||||
self._shutting_down = True
|
||||
# Check if executor was ever initialized (lazy init optimization)
|
||||
if not self._executor_initialized:
|
||||
return
|
||||
loop = getattr(self, "_loop", None)
|
||||
|
||||
@@ -154,12 +154,9 @@ class EventListener(BaseEventListener):
|
||||
self._initialized = True
|
||||
self.formatter = ConsoleFormatter(verbose=True)
|
||||
|
||||
# Initialize trace listener with formatter for memory event handling
|
||||
trace_listener = TraceCollectionListener()
|
||||
trace_listener.formatter = self.formatter
|
||||
|
||||
# ----------- CREW EVENTS -----------
|
||||
|
||||
def setup_listeners(self, crewai_event_bus: CrewAIEventsBus) -> None:
|
||||
|
||||
@crewai_event_bus.on(CCEnvEvent)
|
||||
@@ -187,7 +184,6 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
def on_crew_completed(source: Any, event: CrewKickoffCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
final_string_output = event.output.raw
|
||||
self._telemetry.end_crew(source, final_string_output)
|
||||
|
||||
@@ -231,8 +227,6 @@ class EventListener(BaseEventListener):
|
||||
event.model,
|
||||
)
|
||||
|
||||
# ----------- TASK EVENTS -----------
|
||||
|
||||
def get_task_name(source: Any) -> str | None:
|
||||
return (
|
||||
source.name
|
||||
@@ -252,12 +246,10 @@ class EventListener(BaseEventListener):
|
||||
|
||||
@crewai_event_bus.on(TaskCompletedEvent)
|
||||
def on_task_completed(source: Any, event: TaskCompletedEvent) -> None:
|
||||
# Handle telemetry
|
||||
span = self.execution_spans.pop(source, None)
|
||||
if span:
|
||||
self._telemetry.task_ended(span, source, source.agent.crew)
|
||||
|
||||
# Pass task name if it exists
|
||||
task_name = get_task_name(source)
|
||||
self.formatter.handle_task_status(
|
||||
source.id, source.agent.role, "completed", task_name
|
||||
@@ -270,15 +262,11 @@ class EventListener(BaseEventListener):
|
||||
if source.agent and source.agent.crew:
|
||||
self._telemetry.task_ended(span, source, source.agent.crew)
|
||||
|
||||
# Pass task name if it exists
|
||||
task_name = get_task_name(source)
|
||||
self.formatter.handle_task_status(
|
||||
source.id, source.agent.role, "failed", task_name
|
||||
)
|
||||
|
||||
# ----------- AGENT EVENTS -----------
|
||||
# ----------- LITE AGENT EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LiteAgentExecutionStartedEvent)
|
||||
def on_lite_agent_execution_started(
|
||||
_: Any, event: LiteAgentExecutionStartedEvent
|
||||
@@ -309,8 +297,6 @@ class EventListener(BaseEventListener):
|
||||
**event.agent_info,
|
||||
)
|
||||
|
||||
# ----------- FLOW EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(FlowCreatedEvent)
|
||||
def on_flow_created(_: Any, event: FlowCreatedEvent) -> None:
|
||||
self._telemetry.flow_creation_span(event.flow_name)
|
||||
@@ -374,7 +360,6 @@ class EventListener(BaseEventListener):
|
||||
"paused",
|
||||
)
|
||||
|
||||
# ----------- HUMAN FEEDBACK EVENTS -----------
|
||||
@crewai_event_bus.on(HumanFeedbackRequestedEvent)
|
||||
def on_human_feedback_requested(
|
||||
_: Any, event: HumanFeedbackRequestedEvent
|
||||
@@ -401,7 +386,6 @@ class EventListener(BaseEventListener):
|
||||
outcome=event.outcome,
|
||||
)
|
||||
|
||||
# ----------- TOOL USAGE EVENTS -----------
|
||||
@crewai_event_bus.on(ToolUsageStartedEvent)
|
||||
def on_tool_usage_started(source: Any, event: ToolUsageStartedEvent) -> None:
|
||||
if isinstance(source, LLM):
|
||||
@@ -443,8 +427,6 @@ class EventListener(BaseEventListener):
|
||||
event.run_attempts,
|
||||
)
|
||||
|
||||
# ----------- LLM EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMCallStartedEvent)
|
||||
def on_llm_call_started(_: Any, event: LLMCallStartedEvent) -> None:
|
||||
self.text_stream = StringIO()
|
||||
@@ -472,8 +454,6 @@ class EventListener(BaseEventListener):
|
||||
event.call_type,
|
||||
)
|
||||
|
||||
# ----------- LLM GUARDRAIL EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(LLMGuardrailStartedEvent)
|
||||
def on_llm_guardrail_started(_: Any, event: LLMGuardrailStartedEvent) -> None:
|
||||
guardrail_str = str(event.guardrail)
|
||||
@@ -556,8 +536,6 @@ class EventListener(BaseEventListener):
|
||||
) -> None:
|
||||
self.formatter.handle_knowledge_search_query_failed(event.error)
|
||||
|
||||
# ----------- REASONING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentReasoningStartedEvent)
|
||||
def on_agent_reasoning_started(
|
||||
_: Any, event: AgentReasoningStartedEvent
|
||||
@@ -580,8 +558,6 @@ class EventListener(BaseEventListener):
|
||||
event.error,
|
||||
)
|
||||
|
||||
# ----------- OBSERVATION EVENTS (Plan-and-Execute) -----------
|
||||
|
||||
@crewai_event_bus.on(StepObservationStartedEvent)
|
||||
def on_step_observation_started(
|
||||
_: Any, event: StepObservationStartedEvent
|
||||
@@ -640,8 +616,6 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
self._telemetry.feature_usage_span("planning:goal_achieved_early")
|
||||
|
||||
# ----------- SKILL EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(SkillDiscoveryCompletedEvent)
|
||||
def on_skill_discovery(_: Any, event: SkillDiscoveryCompletedEvent) -> None:
|
||||
self._telemetry.feature_usage_span("skill:discovery")
|
||||
@@ -658,8 +632,6 @@ class EventListener(BaseEventListener):
|
||||
def on_skill_activated(_: Any, event: SkillActivatedEvent) -> None:
|
||||
self._telemetry.feature_usage_span("skill:activated")
|
||||
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(AgentLogsStartedEvent)
|
||||
def on_agent_logs_started(_: Any, event: AgentLogsStartedEvent) -> None:
|
||||
self.formatter.handle_agent_logs_started(
|
||||
@@ -702,7 +674,6 @@ class EventListener(BaseEventListener):
|
||||
def on_a2a_conversation_started(
|
||||
_: Any, event: A2AConversationStartedEvent
|
||||
) -> None:
|
||||
# Store A2A agent name for display in conversation tree
|
||||
if event.a2a_agent_name:
|
||||
self.formatter._current_a2a_agent_name = event.a2a_agent_name
|
||||
|
||||
@@ -757,8 +728,6 @@ class EventListener(BaseEventListener):
|
||||
event.poll_count,
|
||||
)
|
||||
|
||||
# ----------- MCP EVENTS -----------
|
||||
|
||||
@crewai_event_bus.on(MCPConnectionStartedEvent)
|
||||
def on_mcp_connection_started(_: Any, event: MCPConnectionStartedEvent) -> None:
|
||||
self.formatter.handle_mcp_connection_started(
|
||||
@@ -833,8 +802,6 @@ class EventListener(BaseEventListener):
|
||||
)
|
||||
self._telemetry.feature_usage_span("mcp:tool_execution_failed")
|
||||
|
||||
# ----------- MEMORY TELEMETRY -----------
|
||||
|
||||
@crewai_event_bus.on(MemorySaveCompletedEvent)
|
||||
def on_memory_save_completed(_: Any, event: MemorySaveCompletedEvent) -> None:
|
||||
self._telemetry.feature_usage_span("memory:save")
|
||||
|
||||
@@ -403,7 +403,6 @@ class TraceBatchManager:
|
||||
if self.is_current_batch_ephemeral:
|
||||
self.ephemeral_trace_url = return_link
|
||||
|
||||
# Create a properly formatted message with URL on its own line
|
||||
message_parts = [
|
||||
f"✅ Trace batch finalized with session ID: {self.trace_batch_id}",
|
||||
"",
|
||||
|
||||
@@ -474,7 +474,6 @@ class TraceCollectionListener(BaseEventListener):
|
||||
) -> None:
|
||||
self._handle_action_event("agent_reasoning_failed", source, event)
|
||||
|
||||
# Observation events (Plan-and-Execute)
|
||||
@event_bus.on(StepObservationStartedEvent)
|
||||
def on_step_observation_started(
|
||||
source: Any, event: StepObservationStartedEvent
|
||||
|
||||
@@ -526,7 +526,6 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
response = input().strip().lower()
|
||||
result[0] = response in ["y", "yes"]
|
||||
except (EOFError, KeyboardInterrupt, OSError, LookupError):
|
||||
# Handle all input-related errors silently
|
||||
result[0] = False
|
||||
|
||||
ctx = contextvars.copy_context()
|
||||
@@ -540,7 +539,6 @@ def prompt_user_for_trace_viewing(timeout_seconds: int = 20) -> bool:
|
||||
return result[0]
|
||||
|
||||
except Exception:
|
||||
# Suppress any warnings or errors and assume "no"
|
||||
return False
|
||||
|
||||
|
||||
|
||||
@@ -726,11 +726,6 @@ class A2AContentTypeNegotiatedEvent(A2AEventBase):
|
||||
metadata: dict[str, Any] | None = None
|
||||
|
||||
|
||||
# -----------------------------------------------------------------------------
|
||||
# Context Lifecycle Events
|
||||
# -----------------------------------------------------------------------------
|
||||
|
||||
|
||||
class A2AContextCreatedEvent(A2AEventBase):
|
||||
"""Event emitted when an A2A context is created.
|
||||
|
||||
|
||||
@@ -66,7 +66,6 @@ class AgentExecutionErrorEvent(BaseEvent):
|
||||
return self
|
||||
|
||||
|
||||
# New event classes for LiteAgent
|
||||
class LiteAgentExecutionStartedEvent(BaseEvent):
|
||||
"""Event emitted when a LiteAgent starts executing"""
|
||||
|
||||
@@ -94,7 +93,6 @@ class LiteAgentExecutionErrorEvent(BaseEvent):
|
||||
type: Literal["lite_agent_execution_error"] = "lite_agent_execution_error"
|
||||
|
||||
|
||||
# Agent Eval events
|
||||
class AgentEvaluationStartedEvent(BaseEvent):
|
||||
agent_id: str
|
||||
agent_role: str
|
||||
|
||||
@@ -41,7 +41,6 @@ class ToolUsageEvent(BaseEvent):
|
||||
|
||||
super().__init__(**data)
|
||||
|
||||
# Set fingerprint data from the agent
|
||||
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
@@ -101,7 +100,6 @@ class ToolExecutionErrorEvent(BaseEvent):
|
||||
|
||||
def __init__(self, **data: Any) -> None:
|
||||
super().__init__(**data)
|
||||
# Set fingerprint data from the agent
|
||||
if self.agent and hasattr(self.agent, "fingerprint") and self.agent.fingerprint:
|
||||
self.source_fingerprint = self.agent.fingerprint.uuid_str
|
||||
self.source_type = "agent"
|
||||
|
||||
@@ -184,7 +184,6 @@ To enable tracing, do any one of these:
|
||||
"""Print to console. Simplified to only handle panel-based output."""
|
||||
if should_suppress_console_output():
|
||||
return
|
||||
# Skip blank lines during streaming
|
||||
if len(args) == 0 and self._is_streaming:
|
||||
return
|
||||
self.console.print(*args, **kwargs)
|
||||
@@ -874,8 +873,6 @@ To enable tracing, do any one of these:
|
||||
)
|
||||
self.print_panel(error_content, "❌ Search Error", "red")
|
||||
|
||||
# ----------- AGENT REASONING EVENTS -----------
|
||||
|
||||
def handle_reasoning_started(
|
||||
self,
|
||||
attempt: int,
|
||||
@@ -936,8 +933,6 @@ To enable tracing, do any one of these:
|
||||
)
|
||||
self.print_panel(error_content, "❌ Reasoning Error", "red")
|
||||
|
||||
# ----------- OBSERVATION EVENTS (Plan-and-Execute) -----------
|
||||
|
||||
def handle_observation_started(
|
||||
self,
|
||||
agent_role: str,
|
||||
@@ -1082,8 +1077,6 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(content, "🎯 Early Goal Achievement", "green")
|
||||
|
||||
# ----------- AGENT LOGGING EVENTS -----------
|
||||
|
||||
def handle_agent_logs_started(
|
||||
self,
|
||||
agent_role: str,
|
||||
@@ -1096,7 +1089,6 @@ To enable tracing, do any one of these:
|
||||
|
||||
agent_role = agent_role.partition("\n")[0]
|
||||
|
||||
# Create panel content
|
||||
content = Text()
|
||||
content.append("Agent: ", style="white")
|
||||
content.append(f"{agent_role}", style="bright_green bold")
|
||||
@@ -1105,7 +1097,6 @@ To enable tracing, do any one of these:
|
||||
content.append("\n\nTask: ", style="white")
|
||||
content.append(f"{task_description}", style="bright_green")
|
||||
|
||||
# Create and display the panel
|
||||
agent_panel = Panel(
|
||||
content,
|
||||
title="🤖 Agent Started",
|
||||
@@ -1132,7 +1123,6 @@ To enable tracing, do any one of these:
|
||||
agent_role = agent_role.partition("\n")[0]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
# Create tool output content with better formatting
|
||||
output_text = str(formatted_answer.result)
|
||||
if len(output_text) > 2000:
|
||||
output_text = output_text[:1997] + "..."
|
||||
@@ -1144,7 +1134,6 @@ To enable tracing, do any one of these:
|
||||
padding=(1, 2),
|
||||
)
|
||||
|
||||
# Print all panels
|
||||
self.print(output_panel)
|
||||
self.print()
|
||||
|
||||
@@ -1463,7 +1452,6 @@ To enable tracing, do any one of these:
|
||||
crewai_agent_role = self._pending_a2a_agent_role or agent_role or "User"
|
||||
message_content = self._pending_a2a_message or ""
|
||||
|
||||
# Determine status styling
|
||||
if status == "completed":
|
||||
style = "green"
|
||||
status_indicator = "✓"
|
||||
@@ -1505,7 +1493,6 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(content, f"💬 A2A Turn #{turn_number}", style)
|
||||
|
||||
# Clear pending state
|
||||
self._pending_a2a_message = None
|
||||
self._pending_a2a_agent_role = None
|
||||
self._pending_a2a_turn_number = None
|
||||
@@ -1544,14 +1531,11 @@ To enable tracing, do any one of these:
|
||||
|
||||
self.print_panel(content, "❌ A2A Failed", "red")
|
||||
|
||||
# Reset state
|
||||
self.current_a2a_turn_count = 0
|
||||
self._pending_a2a_message = None
|
||||
self._pending_a2a_agent_role = None
|
||||
self._pending_a2a_turn_number = None
|
||||
|
||||
# ----------- MCP EVENTS -----------
|
||||
|
||||
def handle_mcp_connection_started(
|
||||
self,
|
||||
server_name: str,
|
||||
|
||||
@@ -338,10 +338,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
|
||||
self.state.todos = TodoList(items=todos)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Plan-and-Execute: Component Initialization
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
def _ensure_step_executor(self) -> Any:
|
||||
"""Lazily create the StepExecutor (avoids circular imports)."""
|
||||
if self._step_executor is None:
|
||||
@@ -498,10 +494,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
dependency_results=dependency_results,
|
||||
)
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Plan-and-Execute: New Observation-Driven Flow Methods
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@router("step_executed")
|
||||
def observe_step_result(
|
||||
self,
|
||||
@@ -537,7 +529,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
|
||||
self.state.observations[current_todo.step_number] = observation
|
||||
|
||||
# Log observation for debugging
|
||||
self.state.execution_log.append(
|
||||
{
|
||||
"type": "observation",
|
||||
@@ -570,8 +561,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
return "step_observed_medium"
|
||||
return "step_observed_low"
|
||||
|
||||
# -- Low effort: observe → mark complete → continue (no replan/refine) --
|
||||
|
||||
@router("step_observed_low")
|
||||
def handle_step_observed_low(
|
||||
self,
|
||||
@@ -643,8 +632,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
|
||||
return "continue_plan"
|
||||
|
||||
# -- Medium effort: observe → replan on failure only (no refine) --
|
||||
|
||||
@router("step_observed_medium")
|
||||
def handle_step_observed_medium(
|
||||
self,
|
||||
@@ -711,8 +698,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
)
|
||||
return "continue_plan"
|
||||
|
||||
# -- High effort: full observation pipeline (existing behavior) --
|
||||
|
||||
@router("step_observed_high")
|
||||
def decide_next_action(
|
||||
self,
|
||||
@@ -776,7 +761,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
self.state.last_replan_reason = "Step did not complete successfully"
|
||||
return "replan_now"
|
||||
|
||||
# Plan still valid but needs refinement
|
||||
if observation.remaining_plan_still_valid and observation.suggested_refinements:
|
||||
self.state.todos.mark_completed(
|
||||
current_todo.step_number, result=current_todo.result
|
||||
@@ -788,7 +772,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
)
|
||||
return "refine_and_continue"
|
||||
|
||||
# Plan still valid, no refinements needed — just continue
|
||||
self.state.todos.mark_completed(
|
||||
current_todo.step_number, result=current_todo.result
|
||||
)
|
||||
@@ -860,7 +843,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
completed = self.state.todos.get_completed_todos()
|
||||
remaining = self.state.todos.get_pending_todos()
|
||||
|
||||
# Emit goal achieved early event
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
event=GoalAchievedEarlyEvent(
|
||||
@@ -903,7 +885,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
reason = self.state.last_replan_reason or "Dynamic replan triggered"
|
||||
completed = self.state.todos.get_completed_todos()
|
||||
|
||||
# Emit replan triggered event
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
event=PlanReplanTriggeredEvent(
|
||||
@@ -924,10 +905,6 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
return "has_todos"
|
||||
return "all_todos_complete"
|
||||
|
||||
# -------------------------------------------------------------------------
|
||||
# Todo-Driven Execution Flow
|
||||
# -------------------------------------------------------------------------
|
||||
|
||||
@router(generate_plan)
|
||||
def check_todos_available(
|
||||
self,
|
||||
@@ -973,11 +950,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
return "needs_replan"
|
||||
|
||||
if len(ready) == 1:
|
||||
# Mark the single ready todo as running
|
||||
self.state.todos.mark_running(ready[0].step_number)
|
||||
return "single_todo_ready"
|
||||
|
||||
# Multiple todos ready - can parallelize
|
||||
return "multiple_todos_ready"
|
||||
|
||||
@router("single_todo_ready")
|
||||
@@ -1017,10 +992,9 @@ class AgentExecutor(Flow[AgentExecutorState], BaseAgentExecutor):
|
||||
step_timeout=self._get_step_timeout(),
|
||||
)
|
||||
|
||||
# Store result on the todo (do NOT mark completed — observation decides)
|
||||
# Do NOT mark completed here — observation logic decides
|
||||
current.result = result.result
|
||||
|
||||
# Log to audit trail
|
||||
self.state.execution_log.append(
|
||||
{
|
||||
"type": "step_execution",
|
||||
|
||||
@@ -235,19 +235,15 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if isinstance(content, str):
|
||||
messages.append(content)
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Handle message list format
|
||||
messages.extend(
|
||||
msg["content"]
|
||||
for msg in content
|
||||
if isinstance(msg, dict) and "content" in msg
|
||||
)
|
||||
|
||||
# Simple n-gram based similarity detection
|
||||
# For a more robust implementation, consider using embedding-based similarity
|
||||
# NOTE: Uses simple n-gram similarity; embedding-based would be more robust
|
||||
for i in range(len(messages) - 2):
|
||||
for j in range(i + 1, len(messages) - 1):
|
||||
# Check for repeated patterns (simplistic approach)
|
||||
# A more sophisticated approach would use semantic similarity
|
||||
similarity = self._calculate_text_similarity(messages[i], messages[j])
|
||||
if similarity > 0.7: # Arbitrary threshold
|
||||
loop_details.append(
|
||||
@@ -285,7 +281,6 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
if isinstance(content, str):
|
||||
call_lengths.append(len(content))
|
||||
elif isinstance(content, list) and len(content) > 0:
|
||||
# Handle message list format
|
||||
total_length = 0
|
||||
for msg in content:
|
||||
if isinstance(msg, dict) and "content" in msg:
|
||||
@@ -342,10 +337,9 @@ Identify any inefficient reasoning patterns and provide specific suggestions for
|
||||
x = np.arange(len(values))
|
||||
y = np.array(values)
|
||||
|
||||
# Simple linear regression
|
||||
slope = np.polyfit(x, y, 1)[0]
|
||||
|
||||
# Normalize slope to -1 to 1 range
|
||||
# Normalize slope to [-1, 1] using full data range as denominator
|
||||
max_possible_slope = max(values) - min(values)
|
||||
if max_possible_slope > 0:
|
||||
normalized_slope = slope / max_possible_slope
|
||||
|
||||
@@ -89,7 +89,6 @@ class ConsoleProvider:
|
||||
HumanFeedbackRequestedEvent,
|
||||
)
|
||||
|
||||
# Emit feedback requested event
|
||||
crewai_event_bus.emit(
|
||||
flow,
|
||||
HumanFeedbackRequestedEvent(
|
||||
@@ -102,7 +101,6 @@ class ConsoleProvider:
|
||||
),
|
||||
)
|
||||
|
||||
# Pause live updates during human input
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
@@ -110,14 +108,12 @@ class ConsoleProvider:
|
||||
console = formatter.console
|
||||
|
||||
if self.verbose:
|
||||
# Display output with formatting using Rich console
|
||||
console.print("\n" + "═" * 50, style="bold cyan")
|
||||
console.print(" OUTPUT FOR REVIEW", style="bold cyan")
|
||||
console.print("═" * 50 + "\n", style="bold cyan")
|
||||
console.print(context.method_output)
|
||||
console.print("\n" + "═" * 50 + "\n", style="bold cyan")
|
||||
|
||||
# Show message and prompt for feedback
|
||||
console.print(context.message, style="yellow")
|
||||
console.print(
|
||||
"(Press Enter to skip, or type your feedback)\n", style="cyan"
|
||||
@@ -125,7 +121,6 @@ class ConsoleProvider:
|
||||
|
||||
feedback = input("Your feedback: ").strip()
|
||||
|
||||
# Emit feedback received event
|
||||
crewai_event_bus.emit(
|
||||
flow,
|
||||
HumanFeedbackReceivedEvent(
|
||||
@@ -139,7 +134,6 @@ class ConsoleProvider:
|
||||
|
||||
return feedback
|
||||
finally:
|
||||
# Resume live updates
|
||||
formatter.resume_live_updates()
|
||||
|
||||
def request_input(
|
||||
@@ -170,7 +164,6 @@ class ConsoleProvider:
|
||||
"""
|
||||
from crewai.events.event_listener import event_listener
|
||||
|
||||
# Pause live updates during human input
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
@@ -191,5 +184,4 @@ class ConsoleProvider:
|
||||
|
||||
return response
|
||||
finally:
|
||||
# Resume live updates
|
||||
formatter.resume_live_updates()
|
||||
|
||||
@@ -879,18 +879,15 @@ class FlowMeta(ModelMetaclass):
|
||||
routers = set()
|
||||
|
||||
for attr_name, attr_value in namespace.items():
|
||||
# Check for any flow-related attributes
|
||||
if (
|
||||
hasattr(attr_value, "__is_flow_method__")
|
||||
or hasattr(attr_value, "__is_start_method__")
|
||||
or hasattr(attr_value, "__trigger_methods__")
|
||||
or hasattr(attr_value, "__is_router__")
|
||||
):
|
||||
# Register start methods
|
||||
if hasattr(attr_value, "__is_start_method__"):
|
||||
start_methods.append(attr_name)
|
||||
|
||||
# Register listeners and routers
|
||||
if (
|
||||
hasattr(attr_value, "__trigger_methods__")
|
||||
and attr_value.__trigger_methods__ is not None
|
||||
@@ -913,14 +910,13 @@ class FlowMeta(ModelMetaclass):
|
||||
and attr_value.__is_router__
|
||||
):
|
||||
routers.add(attr_name)
|
||||
# First check for explicit __router_paths__ (set by @human_feedback(emit=[...]))
|
||||
# Explicit __router_paths__ set by @human_feedback(emit=[...]) takes priority over source analysis
|
||||
if (
|
||||
hasattr(attr_value, "__router_paths__")
|
||||
and attr_value.__router_paths__
|
||||
):
|
||||
router_paths[attr_name] = attr_value.__router_paths__
|
||||
else:
|
||||
# Fall back to source code analysis for @router methods
|
||||
possible_returns = get_possible_return_constants(attr_value)
|
||||
if possible_returns:
|
||||
router_paths[attr_name] = possible_returns
|
||||
@@ -934,7 +930,6 @@ class FlowMeta(ModelMetaclass):
|
||||
and attr_value.__is_router__
|
||||
):
|
||||
routers.add(attr_name)
|
||||
# Get router paths from the decorator attribute
|
||||
if (
|
||||
hasattr(attr_value, "__router_paths__")
|
||||
and attr_value.__router_paths__
|
||||
@@ -1179,12 +1174,10 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
flow_name = sanitize_scope_name(self.name or self.__class__.__name__)
|
||||
self.memory = Memory(root_scope=f"/flow/{flow_name}")
|
||||
|
||||
# Register all flow-related methods
|
||||
for method_name in dir(self):
|
||||
if not method_name.startswith("_"):
|
||||
method = getattr(self, method_name)
|
||||
if is_flow_method(method):
|
||||
# Ensure method is bound to this instance
|
||||
if not hasattr(method, "__self__"):
|
||||
method = method.__get__(self, self.__class__)
|
||||
self._methods[method.__name__] = method
|
||||
@@ -1465,23 +1458,15 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
persistence = SQLiteFlowPersistence()
|
||||
|
||||
# Load pending feedback context and state
|
||||
loaded = persistence.load_pending_feedback(flow_id)
|
||||
if loaded is None:
|
||||
raise ValueError(f"No pending feedback found for flow_id: {flow_id}")
|
||||
|
||||
state_data, pending_context = loaded
|
||||
|
||||
# Create flow instance with persistence
|
||||
instance = cls(persistence=persistence, **kwargs)
|
||||
|
||||
# Restore state
|
||||
instance._initialize_state(state_data)
|
||||
|
||||
# Store pending context for resume
|
||||
instance._pending_feedback_context = pending_context
|
||||
|
||||
# Mark that we're resuming execution
|
||||
instance._is_execution_resuming = True
|
||||
|
||||
return instance
|
||||
@@ -1625,15 +1610,12 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if llm is None:
|
||||
llm = _deserialize_llm_from_context(context.llm)
|
||||
|
||||
# Determine outcome
|
||||
collapsed_outcome: str | None = None
|
||||
|
||||
if not feedback.strip():
|
||||
# Empty feedback
|
||||
if default_outcome:
|
||||
collapsed_outcome = default_outcome
|
||||
elif emit:
|
||||
# No default and no feedback - use first outcome
|
||||
collapsed_outcome = emit[0]
|
||||
elif emit:
|
||||
if llm is not None:
|
||||
@@ -1645,7 +1627,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
collapsed_outcome = emit[0]
|
||||
|
||||
# Create result
|
||||
result = HumanFeedbackResult(
|
||||
output=context.method_output,
|
||||
feedback=feedback,
|
||||
@@ -1655,7 +1636,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
metadata=context.metadata,
|
||||
)
|
||||
|
||||
# Store in flow instance
|
||||
self.human_feedback_history.append(result)
|
||||
self.last_human_feedback = result
|
||||
|
||||
@@ -1663,11 +1643,9 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
self._pending_feedback_context = None
|
||||
|
||||
# Clear pending feedback from persistence
|
||||
if self.persistence:
|
||||
self.persistence.clear_pending_feedback(context.flow_id)
|
||||
|
||||
# Emit feedback received event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
MethodExecutionFinishedEvent(
|
||||
@@ -1722,7 +1700,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
state_data=state_data,
|
||||
)
|
||||
|
||||
# Emit flow paused event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowPausedEvent(
|
||||
@@ -1735,7 +1712,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
emit=e.context.emit,
|
||||
),
|
||||
)
|
||||
# Return the pending exception instead of raising
|
||||
return e
|
||||
raise
|
||||
|
||||
@@ -1827,7 +1803,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if init_state is dict:
|
||||
return cast(T, {"id": str(uuid4())})
|
||||
|
||||
# Handle dictionary instance case
|
||||
if isinstance(init_state, dict):
|
||||
new_state = dict(init_state) # Copy to avoid mutations
|
||||
if "id" not in new_state:
|
||||
@@ -1928,27 +1903,22 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
"""
|
||||
if isinstance(self._state, dict):
|
||||
# For dict states, update with inputs
|
||||
# If inputs contains an id, use it (for restoring from persistence)
|
||||
# Otherwise preserve the current id or generate a new one
|
||||
# If inputs contains an id, use it (for restoring from persistence);
|
||||
# otherwise preserve the current id or generate a new one.
|
||||
current_id = self._state.get("id")
|
||||
inputs_has_id = "id" in inputs
|
||||
|
||||
# Update specified fields
|
||||
for k, v in inputs.items():
|
||||
self._state[k] = v
|
||||
|
||||
# Ensure ID is set: prefer inputs id, then current id, then generate
|
||||
if not inputs_has_id:
|
||||
if current_id:
|
||||
self._state["id"] = current_id
|
||||
elif "id" not in self._state:
|
||||
self._state["id"] = str(uuid4())
|
||||
elif isinstance(self._state, BaseModel):
|
||||
# For BaseModel states, preserve existing fields unless overridden
|
||||
try:
|
||||
model = self._state
|
||||
# Get current state as dict
|
||||
if hasattr(model, "model_dump"):
|
||||
current_state = model.model_dump()
|
||||
elif hasattr(model, "dict"):
|
||||
@@ -1958,19 +1928,14 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
k: v for k, v in model.__dict__.items() if not k.startswith("_")
|
||||
}
|
||||
|
||||
# Create new state with preserved fields and updates
|
||||
new_state = {**current_state, **inputs}
|
||||
|
||||
# Create new instance with merged state
|
||||
model_class = type(model)
|
||||
if hasattr(model_class, "model_validate"):
|
||||
# Pydantic v2
|
||||
self._state = cast(T, model_class.model_validate(new_state))
|
||||
elif hasattr(model_class, "parse_obj"):
|
||||
# Pydantic v1
|
||||
self._state = cast(T, model_class.parse_obj(new_state))
|
||||
else:
|
||||
# Fallback for other BaseModel implementations
|
||||
self._state = cast(T, model_class(**new_state))
|
||||
except ValidationError as e:
|
||||
raise ValueError(f"Invalid inputs for structured state: {e}") from e
|
||||
@@ -1987,26 +1952,20 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
ValueError: If validation fails for structured state
|
||||
TypeError: If state is neither BaseModel nor dictionary
|
||||
"""
|
||||
# When restoring from persistence, use the stored ID
|
||||
stored_id = stored_state.get("id")
|
||||
if not stored_id:
|
||||
raise ValueError("Stored state must have an 'id' field")
|
||||
|
||||
if isinstance(self._state, dict):
|
||||
# For dict states, update all fields from stored state
|
||||
self._state.clear()
|
||||
self._state.update(stored_state)
|
||||
elif isinstance(self._state, BaseModel):
|
||||
# For BaseModel states, create new instance with stored values
|
||||
model = self._state
|
||||
if hasattr(model, "model_validate"):
|
||||
# Pydantic v2
|
||||
self._state = cast(T, type(model).model_validate(stored_state))
|
||||
elif hasattr(model, "parse_obj"):
|
||||
# Pydantic v1
|
||||
self._state = cast(T, type(model).parse_obj(stored_state))
|
||||
else:
|
||||
# Fallback for other BaseModel implementations
|
||||
self._state = cast(T, type(model)(**stored_state))
|
||||
else:
|
||||
raise TypeError(f"State must be dict or BaseModel, got {type(self._state)}")
|
||||
@@ -2927,9 +2886,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
await asyncio.gather(*tasks)
|
||||
|
||||
if current_trigger in router_results:
|
||||
# Find start methods triggered by this router result
|
||||
for method_name in self._start_methods:
|
||||
# Check if this start method is triggered by the current trigger
|
||||
if method_name in self._listeners:
|
||||
condition_data = self._listeners[method_name]
|
||||
should_trigger = False
|
||||
@@ -2941,15 +2898,13 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
should_trigger = current_trigger in all_methods
|
||||
|
||||
if should_trigger:
|
||||
# Execute conditional start method triggered by router result
|
||||
if method_name in self._completed_methods:
|
||||
# For cyclic re-execution, temporarily clear resumption flag
|
||||
# Cyclic re-execution: temporarily clear resumption flag so the method actually re-runs
|
||||
was_resuming = self._is_execution_resuming
|
||||
self._is_execution_resuming = False
|
||||
await self._execute_start_method(method_name)
|
||||
self._is_execution_resuming = was_resuming
|
||||
else:
|
||||
# First-time execution of conditional start
|
||||
await self._execute_start_method(method_name)
|
||||
|
||||
def _evaluate_condition(
|
||||
@@ -3191,7 +3146,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
listener_name, method
|
||||
)
|
||||
|
||||
# Execute listeners (and possibly routers) of this listener
|
||||
await self._execute_listeners(
|
||||
listener_name, listener_result, finished_event_id
|
||||
)
|
||||
@@ -3208,8 +3162,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
e._flow_listener_logged = True # type: ignore[attr-defined]
|
||||
raise
|
||||
|
||||
# ── User Input (self.ask) ────────────────────────────────────────
|
||||
|
||||
def _resolve_input_provider(self) -> InputProvider:
|
||||
"""Resolve the input provider using the priority chain.
|
||||
|
||||
@@ -3324,7 +3276,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
method_name = current_flow_method_name.get("unknown")
|
||||
|
||||
# Emit input requested event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowInputRequestedEvent(
|
||||
@@ -3336,7 +3287,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
|
||||
# Auto-checkpoint state before waiting
|
||||
self._checkpoint_state_for_ask()
|
||||
|
||||
provider = self._resolve_input_provider()
|
||||
@@ -3369,7 +3319,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
logger.debug("Input provider error in ask()", exc_info=True)
|
||||
raw = None
|
||||
|
||||
# Normalize provider response: str, InputResponse, or None
|
||||
response: str | None = None
|
||||
response_metadata: dict[str, Any] | None = None
|
||||
|
||||
@@ -3381,7 +3330,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
response = None
|
||||
|
||||
# Record in history
|
||||
self._input_history.append(
|
||||
{
|
||||
"message": message,
|
||||
@@ -3393,7 +3341,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
}
|
||||
)
|
||||
|
||||
# Emit input received event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
FlowInputReceivedEvent(
|
||||
@@ -3432,7 +3379,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
HumanFeedbackRequestedEvent,
|
||||
)
|
||||
|
||||
# Emit feedback requested event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanFeedbackRequestedEvent(
|
||||
@@ -3445,19 +3391,16 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
),
|
||||
)
|
||||
|
||||
# Pause live updates during human input
|
||||
formatter = event_listener.formatter
|
||||
formatter.pause_live_updates()
|
||||
|
||||
try:
|
||||
# Display output with formatting using centralized Rich console
|
||||
formatter.console.print("\n" + "═" * 50, style="bold cyan")
|
||||
formatter.console.print(" OUTPUT FOR REVIEW", style="bold cyan")
|
||||
formatter.console.print("═" * 50 + "\n", style="bold cyan")
|
||||
formatter.console.print(output)
|
||||
formatter.console.print("\n" + "═" * 50 + "\n", style="bold cyan")
|
||||
|
||||
# Show message and prompt for feedback
|
||||
formatter.console.print(message, style="yellow")
|
||||
formatter.console.print(
|
||||
"(Press Enter to skip, or type your feedback)\n", style="cyan"
|
||||
@@ -3465,7 +3408,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
feedback = input("Your feedback: ").strip()
|
||||
|
||||
# Emit feedback received event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
HumanFeedbackReceivedEvent(
|
||||
@@ -3479,7 +3421,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
|
||||
return feedback
|
||||
finally:
|
||||
# Resume live updates
|
||||
formatter.resume_live_updates()
|
||||
|
||||
def _collapse_to_outcome(
|
||||
@@ -3521,7 +3462,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
else:
|
||||
raise ValueError(f"Invalid llm type: {type(llm)}. Expected str or BaseLLM.")
|
||||
|
||||
# Dynamically create a Pydantic model with constrained outcomes
|
||||
outcomes_tuple = tuple(outcomes)
|
||||
|
||||
class FeedbackOutcome(BaseModel):
|
||||
@@ -3539,8 +3479,7 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
|
||||
try:
|
||||
# Try structured output first (function calling)
|
||||
# Note: LLM.call with response_model returns JSON string, not Pydantic model
|
||||
# NOTE: LLM.call with response_model returns JSON string, not a Pydantic model
|
||||
response = llm_instance.call(
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
response_model=FeedbackOutcome,
|
||||
@@ -3567,7 +3506,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
return outcomes[0]
|
||||
|
||||
except Exception as e:
|
||||
# Fallback to simple prompting if structured output fails
|
||||
logger.warning(
|
||||
f"Structured output failed, falling back to simple prompting: {e}"
|
||||
)
|
||||
@@ -3577,7 +3515,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
)
|
||||
response_clean = str(response).strip()
|
||||
|
||||
# Exact match (case-insensitive)
|
||||
for outcome in outcomes:
|
||||
if outcome.lower() == response_clean.lower():
|
||||
return outcome
|
||||
@@ -3593,7 +3530,6 @@ class Flow(BaseModel, Generic[T], metaclass=FlowMeta):
|
||||
if best_outcome is not None:
|
||||
return best_outcome
|
||||
|
||||
# Fallback to first outcome
|
||||
logger.warning(
|
||||
f"Could not match LLM response '{response_clean}' to outcomes {list(outcomes)}. "
|
||||
f"Falling back to first outcome: {outcomes[0]}"
|
||||
|
||||
@@ -68,5 +68,4 @@ class FlowConfig:
|
||||
self._input_provider = provider
|
||||
|
||||
|
||||
# Singleton instance
|
||||
flow_config = FlowConfig()
|
||||
|
||||
@@ -191,7 +191,6 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
True if crew reference detected, False otherwise.
|
||||
"""
|
||||
try:
|
||||
# Get the underlying function from wrapper
|
||||
func = method
|
||||
if hasattr(method, "_meth"):
|
||||
func = method._meth
|
||||
@@ -201,7 +200,6 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
source = inspect.getsource(func)
|
||||
source = textwrap.dedent(source)
|
||||
|
||||
# Patterns that indicate Crew usage
|
||||
crew_patterns = [
|
||||
r"\.crew\(\)", # .crew() method call
|
||||
r"Crew\s*\(", # Crew( instantiation
|
||||
@@ -215,7 +213,6 @@ def _detect_crew_reference(method: Any) -> bool:
|
||||
|
||||
return False
|
||||
except (OSError, TypeError):
|
||||
# Can't get source code - assume no crew reference
|
||||
return False
|
||||
|
||||
|
||||
@@ -231,7 +228,6 @@ def _extract_trigger_methods(method: Any) -> tuple[list[str], str | None]:
|
||||
trigger_methods: list[str] = []
|
||||
condition_type: str | None = None
|
||||
|
||||
# First try __trigger_methods__ (populated for simple conditions)
|
||||
if hasattr(method, "__trigger_methods__") and method.__trigger_methods__:
|
||||
trigger_methods = [str(m) for m in method.__trigger_methods__]
|
||||
|
||||
@@ -264,11 +260,9 @@ def _extract_router_paths(
|
||||
"""
|
||||
method_name = getattr(method, "__name__", "")
|
||||
|
||||
# First check if there are __router_paths__ on the method itself
|
||||
if hasattr(method, "__router_paths__") and method.__router_paths__:
|
||||
return [str(p) for p in method.__router_paths__]
|
||||
|
||||
# Then check the class-level registry
|
||||
if method_name in router_paths_registry:
|
||||
return [str(p) for p in router_paths_registry[method_name]]
|
||||
|
||||
@@ -330,7 +324,6 @@ def _generate_edges(
|
||||
"""
|
||||
edges: list[EdgeInfo] = []
|
||||
|
||||
# Generate edges from listeners (listen edges)
|
||||
for listener_name, condition_data in listeners.items():
|
||||
trigger_methods: list[str] = []
|
||||
|
||||
@@ -340,7 +333,6 @@ def _generate_edges(
|
||||
elif isinstance(condition_data, dict):
|
||||
trigger_methods = _extract_all_methods_from_condition(condition_data)
|
||||
|
||||
# Create edges from each trigger to the listener
|
||||
edges.extend(
|
||||
EdgeInfo(
|
||||
from_method=trigger,
|
||||
@@ -352,10 +344,8 @@ def _generate_edges(
|
||||
if trigger in all_methods
|
||||
)
|
||||
|
||||
# Generate edges from routers (route edges)
|
||||
for router_name, paths in router_paths.items():
|
||||
for path in paths:
|
||||
# Find listeners that listen to this path
|
||||
for listener_name, condition_data in listeners.items():
|
||||
path_triggers: list[str] = []
|
||||
|
||||
@@ -393,11 +383,10 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
"""
|
||||
state_type: type | None = None
|
||||
|
||||
# Check for _initial_state_t set by __class_getitem__
|
||||
# _initial_state_t is set by Flow.__class_getitem__
|
||||
if hasattr(flow_class, "_initial_state_t"):
|
||||
state_type = flow_class._initial_state_t
|
||||
|
||||
# Check initial_state class attribute
|
||||
if state_type is None and hasattr(flow_class, "initial_state"):
|
||||
initial_state = flow_class.initial_state
|
||||
if isinstance(initial_state, type) and issubclass(initial_state, BaseModel):
|
||||
@@ -405,7 +394,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
elif isinstance(initial_state, BaseModel):
|
||||
state_type = type(initial_state)
|
||||
|
||||
# Check __orig_bases__ for generic parameters
|
||||
if state_type is None and hasattr(flow_class, "__orig_bases__"):
|
||||
for base in flow_class.__orig_bases__:
|
||||
origin = get_origin(base)
|
||||
@@ -420,7 +408,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
if state_type is None or not issubclass(state_type, BaseModel):
|
||||
return None
|
||||
|
||||
# Extract fields from the Pydantic model
|
||||
fields: list[StateFieldInfo] = []
|
||||
try:
|
||||
model_fields = state_type.model_fields
|
||||
@@ -428,7 +415,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
field_type_str = "Any"
|
||||
if field_info.annotation is not None:
|
||||
field_type_str = str(field_info.annotation)
|
||||
# Clean up the type string
|
||||
field_type_str = field_type_str.replace("typing.", "")
|
||||
field_type_str = field_type_str.replace("<class '", "").replace(
|
||||
"'>", ""
|
||||
@@ -441,7 +427,6 @@ def _extract_state_schema(flow_class: type) -> StateSchemaInfo | None:
|
||||
and not callable(field_info.default)
|
||||
):
|
||||
try:
|
||||
# Try to serialize the default value
|
||||
default_value = field_info.default
|
||||
except Exception:
|
||||
default_value = str(field_info.default)
|
||||
@@ -474,7 +459,6 @@ def _detect_flow_inputs(flow_class: type) -> list[str]:
|
||||
"""
|
||||
inputs: list[str] = []
|
||||
|
||||
# Check for inputs in __init__ signature beyond standard Flow params
|
||||
try:
|
||||
init_method = flow_class.__init__ # type: ignore[misc]
|
||||
init_sig = inspect.signature(init_method)
|
||||
@@ -533,7 +517,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
f"Got {type(flow_class).__name__}"
|
||||
)
|
||||
|
||||
# Get class-level metadata set by FlowMeta
|
||||
start_methods: list[str] = getattr(flow_class, "_start_methods", [])
|
||||
listeners: dict[str, Any] = getattr(flow_class, "_listeners", {})
|
||||
routers: set[str] = getattr(flow_class, "_routers", set())
|
||||
@@ -541,7 +524,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
flow_class, "_router_paths", {}
|
||||
)
|
||||
|
||||
# Collect all flow methods
|
||||
methods: list[MethodInfo] = []
|
||||
all_method_names: set[str] = set()
|
||||
|
||||
@@ -554,7 +536,6 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
except AttributeError:
|
||||
continue
|
||||
|
||||
# Check if it's a flow method
|
||||
is_flow_method = (
|
||||
isinstance(attr, (FlowMethod, StartMethod, ListenMethod, RouterMethod))
|
||||
or hasattr(attr, "__is_flow_method__")
|
||||
@@ -568,21 +549,16 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
|
||||
all_method_names.add(attr_name)
|
||||
|
||||
# Get method type
|
||||
method_type = _get_method_type(attr_name, attr, start_methods, routers)
|
||||
|
||||
# Get trigger methods and condition type
|
||||
trigger_methods, condition_type = _extract_trigger_methods(attr)
|
||||
|
||||
# Get router paths if applicable
|
||||
router_paths_list: list[str] = []
|
||||
if method_type in ("router", "start_router"):
|
||||
router_paths_list = _extract_router_paths(attr, router_paths_registry)
|
||||
|
||||
# Check for human feedback
|
||||
has_hf = _has_human_feedback(attr)
|
||||
|
||||
# Check for crew reference
|
||||
has_crew = _detect_crew_reference(attr)
|
||||
|
||||
method_info = MethodInfo(
|
||||
@@ -596,16 +572,12 @@ def flow_structure(flow_class: type) -> FlowStructureInfo:
|
||||
)
|
||||
methods.append(method_info)
|
||||
|
||||
# Generate edges
|
||||
edges = _generate_edges(listeners, routers, router_paths_registry, all_method_names)
|
||||
|
||||
# Extract state schema
|
||||
state_schema = _extract_state_schema(flow_class)
|
||||
|
||||
# Detect inputs
|
||||
inputs = _detect_flow_inputs(flow_class)
|
||||
|
||||
# Get flow description from docstring
|
||||
description: str | None = None
|
||||
if flow_class.__doc__:
|
||||
description = flow_class.__doc__.strip()
|
||||
|
||||
@@ -339,7 +339,6 @@ def human_feedback(
|
||||
return "Content to review..."
|
||||
```
|
||||
"""
|
||||
# Validation at decoration time
|
||||
if emit is not None:
|
||||
if not llm:
|
||||
raise ValueError(
|
||||
@@ -359,8 +358,6 @@ def human_feedback(
|
||||
def decorator(func: F) -> F:
|
||||
"""Inner decorator that wraps the function."""
|
||||
|
||||
# -- HITL learning helpers (only used when learn=True) --------
|
||||
|
||||
def _get_hitl_prompt(key: str) -> str:
|
||||
"""Read a HITL prompt from the i18n translations."""
|
||||
from crewai.utilities.i18n import I18N_DEFAULT
|
||||
@@ -485,8 +482,6 @@ def human_feedback(
|
||||
exc_info=True,
|
||||
)
|
||||
|
||||
# -- Core feedback helpers ------------------------------------
|
||||
|
||||
def _build_feedback_context(
|
||||
flow_instance: Flow[Any], method_output: Any
|
||||
) -> tuple[Any, Any]:
|
||||
@@ -565,15 +560,12 @@ def human_feedback(
|
||||
raw_feedback: str,
|
||||
) -> HumanFeedbackResult | str:
|
||||
"""Process feedback and return result or outcome."""
|
||||
# Determine outcome
|
||||
collapsed_outcome: str | None = None
|
||||
|
||||
if not raw_feedback.strip():
|
||||
# Empty feedback
|
||||
if default_outcome:
|
||||
collapsed_outcome = default_outcome
|
||||
elif emit:
|
||||
# No default and no feedback - use first outcome
|
||||
collapsed_outcome = emit[0]
|
||||
elif emit:
|
||||
if llm is not None:
|
||||
@@ -585,7 +577,6 @@ def human_feedback(
|
||||
else:
|
||||
collapsed_outcome = emit[0]
|
||||
|
||||
# Create result
|
||||
result = HumanFeedbackResult(
|
||||
output=method_output,
|
||||
feedback=raw_feedback,
|
||||
@@ -595,7 +586,6 @@ def human_feedback(
|
||||
metadata=metadata or {},
|
||||
)
|
||||
|
||||
# Store in flow instance
|
||||
flow_instance.human_feedback_history.append(result)
|
||||
flow_instance.last_human_feedback = result
|
||||
|
||||
@@ -607,19 +597,17 @@ def human_feedback(
|
||||
return result
|
||||
|
||||
if asyncio.iscoroutinefunction(func):
|
||||
# Async wrapper
|
||||
|
||||
@wraps(func)
|
||||
async def async_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
method_output = await func(self, *args, **kwargs)
|
||||
|
||||
# Pre-review: apply past HITL lessons before human sees it
|
||||
if learn and getattr(self, "memory", None) is not None:
|
||||
method_output = _pre_review_with_lessons(self, method_output)
|
||||
|
||||
raw_feedback = await _request_feedback_async(self, method_output)
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
@@ -627,10 +615,10 @@ def human_feedback(
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
# Stash the real method output for final flow result when emit is set
|
||||
# (result is the collapsed outcome string for routing, but we want to
|
||||
# preserve the actual method output as the flow's final result)
|
||||
# Uses per-method dict for concurrency safety and to handle None returns
|
||||
# Stash the real method output for final flow result when emit is set:
|
||||
# result is the collapsed outcome string for routing, but we preserve the
|
||||
# actual method output as the flow's final result. Uses per-method dict for
|
||||
# concurrency safety and to handle None returns.
|
||||
if emit:
|
||||
self._human_feedback_method_outputs[func.__name__] = method_output
|
||||
|
||||
@@ -638,19 +626,17 @@ def human_feedback(
|
||||
|
||||
wrapper: Any = async_wrapper
|
||||
else:
|
||||
# Sync wrapper
|
||||
|
||||
@wraps(func)
|
||||
def sync_wrapper(self: Flow[Any], *args: Any, **kwargs: Any) -> Any:
|
||||
method_output = func(self, *args, **kwargs)
|
||||
|
||||
# Pre-review: apply past HITL lessons before human sees it
|
||||
if learn and getattr(self, "memory", None) is not None:
|
||||
method_output = _pre_review_with_lessons(self, method_output)
|
||||
|
||||
raw_feedback = _request_feedback(self, method_output)
|
||||
result = _process_feedback(self, method_output, raw_feedback)
|
||||
|
||||
# Distill: extract lessons from output + feedback, store in memory
|
||||
if (
|
||||
learn
|
||||
and getattr(self, "memory", None) is not None
|
||||
@@ -658,10 +644,10 @@ def human_feedback(
|
||||
):
|
||||
_distill_and_store_lessons(self, method_output, raw_feedback)
|
||||
|
||||
# Stash the real method output for final flow result when emit is set
|
||||
# (result is the collapsed outcome string for routing, but we want to
|
||||
# preserve the actual method output as the flow's final result)
|
||||
# Uses per-method dict for concurrency safety and to handle None returns
|
||||
# Stash the real method output for final flow result when emit is set:
|
||||
# result is the collapsed outcome string for routing, but we preserve the
|
||||
# actual method output as the flow's final result. Uses per-method dict for
|
||||
# concurrency safety and to handle None returns.
|
||||
if emit:
|
||||
self._human_feedback_method_outputs[func.__name__] = method_output
|
||||
|
||||
@@ -669,7 +655,6 @@ def human_feedback(
|
||||
|
||||
wrapper = sync_wrapper
|
||||
|
||||
# Preserve existing Flow decorator attributes
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
@@ -680,7 +665,7 @@ def human_feedback(
|
||||
if hasattr(func, attr):
|
||||
setattr(wrapper, attr, getattr(func, attr))
|
||||
|
||||
# Add human feedback specific attributes (create config inline to avoid race conditions)
|
||||
# Create config inline to avoid race conditions
|
||||
wrapper.__human_feedback_config__ = HumanFeedbackConfig(
|
||||
message=message,
|
||||
emit=emit,
|
||||
|
||||
@@ -44,7 +44,6 @@ if TYPE_CHECKING:
|
||||
logger = logging.getLogger(__name__)
|
||||
T = TypeVar("T")
|
||||
|
||||
# Constants for log messages
|
||||
LOG_MESSAGES: Final[dict[str, str]] = {
|
||||
"save_state": "Saving flow state to memory for ID: {}",
|
||||
"save_error": "Failed to persist state for method {}: {}",
|
||||
@@ -100,7 +99,6 @@ class PersistenceDecorator:
|
||||
if not flow_uuid:
|
||||
raise ValueError("Flow state must have an 'id' field for persistence")
|
||||
|
||||
# Log state saving only if verbose is True
|
||||
if verbose:
|
||||
PRINTER.print(
|
||||
LOG_MESSAGES["save_state"].format(flow_uuid), color="cyan"
|
||||
@@ -169,7 +167,6 @@ def persist(
|
||||
actual_persistence = persistence or SQLiteFlowPersistence()
|
||||
|
||||
if isinstance(target, type):
|
||||
# Class decoration
|
||||
original_init = target.__init__ # type: ignore[misc]
|
||||
|
||||
@functools.wraps(original_init)
|
||||
@@ -180,7 +177,7 @@ def persist(
|
||||
|
||||
target.__init__ = new_init # type: ignore[misc]
|
||||
|
||||
# Store original methods to preserve their decorators
|
||||
# Preserve original methods' decorators
|
||||
original_methods = {
|
||||
name: method
|
||||
for name, method in target.__dict__.items()
|
||||
@@ -194,10 +191,9 @@ def persist(
|
||||
)
|
||||
}
|
||||
|
||||
# Create wrapped versions of the methods that include persistence
|
||||
for name, method in original_methods.items():
|
||||
if asyncio.iscoroutinefunction(method):
|
||||
# Create a closure to capture the current name and method
|
||||
# Closure captures the current name and method
|
||||
def create_async_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@@ -215,7 +211,6 @@ def persist(
|
||||
|
||||
wrapped = create_async_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
@@ -226,10 +221,9 @@ def persist(
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
else:
|
||||
# Create a closure to capture the current name and method
|
||||
|
||||
def create_sync_wrapper(
|
||||
method_name: str, original_method: Callable[..., Any]
|
||||
) -> Callable[..., Any]:
|
||||
@@ -245,7 +239,6 @@ def persist(
|
||||
|
||||
wrapped = create_sync_wrapper(name, method)
|
||||
|
||||
# Preserve all original decorators and attributes
|
||||
for attr in [
|
||||
"__is_start_method__",
|
||||
"__trigger_methods__",
|
||||
@@ -256,11 +249,9 @@ def persist(
|
||||
setattr(wrapped, attr, getattr(method, attr))
|
||||
wrapped.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
# Update the class with the wrapped method
|
||||
setattr(target, name, wrapped)
|
||||
|
||||
return target
|
||||
# Method decoration
|
||||
method = target
|
||||
method.__is_flow_method__ = True # type: ignore[attr-defined]
|
||||
|
||||
|
||||
@@ -75,7 +75,6 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
sqlite3.connect(self.db_path, timeout=30) as conn,
|
||||
):
|
||||
conn.execute("PRAGMA journal_mode=WAL")
|
||||
# Main state table
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS flow_states (
|
||||
@@ -87,7 +86,6 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Add index for faster UUID lookups
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_flow_states_uuid
|
||||
@@ -95,7 +93,6 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
"""
|
||||
)
|
||||
|
||||
# Pending feedback table for async HITL
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE TABLE IF NOT EXISTS pending_feedback (
|
||||
@@ -107,7 +104,6 @@ class SQLiteFlowPersistence(FlowPersistence):
|
||||
)
|
||||
"""
|
||||
)
|
||||
# Add index for faster UUID lookups on pending feedback
|
||||
conn.execute(
|
||||
"""
|
||||
CREATE INDEX IF NOT EXISTS idx_pending_feedback_uuid
|
||||
|
||||
@@ -175,7 +175,6 @@ def get_possible_return_constants(
|
||||
try:
|
||||
source = inspect.getsource(function)
|
||||
except OSError:
|
||||
# Can't get source code
|
||||
return None
|
||||
except Exception as e:
|
||||
if verbose:
|
||||
@@ -186,9 +185,7 @@ def get_possible_return_constants(
|
||||
return None
|
||||
|
||||
try:
|
||||
# Remove leading indentation
|
||||
source = textwrap.dedent(source)
|
||||
# Parse the source code into an AST
|
||||
code_ast = ast.parse(source)
|
||||
except IndentationError as e:
|
||||
if verbose:
|
||||
@@ -254,12 +251,10 @@ def get_possible_return_constants(
|
||||
|
||||
class VariableAssignmentVisitor(ast.NodeVisitor):
|
||||
def visit_Assign(self, node: ast.Assign) -> None:
|
||||
# Check if this assignment is assigning a dictionary literal to a variable
|
||||
if isinstance(node.value, ast.Dict) and len(node.targets) == 1:
|
||||
target = node.targets[0]
|
||||
if isinstance(target, ast.Name):
|
||||
var_name = target.id
|
||||
# Extract string values from the dictionary
|
||||
dict_values = [
|
||||
val.value
|
||||
for val in node.value.values
|
||||
@@ -328,13 +323,10 @@ def get_possible_return_constants(
|
||||
def visit_If(self, node: ast.If) -> None:
|
||||
self.generic_visit(node)
|
||||
|
||||
# Try to get the class context to infer state attribute values
|
||||
try:
|
||||
if hasattr(function, "__self__"):
|
||||
# Method is bound, get the class
|
||||
class_obj = function.__self__.__class__
|
||||
elif hasattr(function, "__qualname__") and "." in function.__qualname__:
|
||||
# Method is unbound but we can try to get class from module
|
||||
class_name = function.__qualname__.rsplit(".", 1)[0]
|
||||
if hasattr(function, "__globals__"):
|
||||
class_obj = function.__globals__.get(class_name)
|
||||
@@ -349,7 +341,6 @@ def get_possible_return_constants(
|
||||
class_source = textwrap.dedent(class_source)
|
||||
class_ast = ast.parse(class_source)
|
||||
|
||||
# Look for comparisons and assignments involving state attributes
|
||||
class StateAttributeVisitor(ast.NodeVisitor):
|
||||
def visit_Compare(self, node: ast.Compare) -> None:
|
||||
"""Find comparisons like: self.state.attr == "value" """
|
||||
@@ -370,7 +361,6 @@ def get_possible_return_constants(
|
||||
comparator.value
|
||||
)
|
||||
|
||||
# Also check right side
|
||||
for comparator in node.comparators:
|
||||
right_attr = get_attribute_chain(comparator)
|
||||
if (
|
||||
@@ -439,13 +429,11 @@ def calculate_node_levels(flow: Any) -> dict[str, int]:
|
||||
visited: set[str] = set()
|
||||
pending_and_listeners: dict[str, set[str]] = {}
|
||||
|
||||
# Make all start methods at level 0
|
||||
for method_name, method in flow._methods.items():
|
||||
if hasattr(method, "__is_start_method__"):
|
||||
levels[method_name] = 0
|
||||
queue.append(method_name)
|
||||
|
||||
# Precompute listener dependencies
|
||||
or_listeners = defaultdict(list)
|
||||
and_listeners = defaultdict(set)
|
||||
for listener_name, condition_data in flow._listeners.items():
|
||||
@@ -463,7 +451,6 @@ def calculate_node_levels(flow: Any) -> dict[str, int]:
|
||||
elif condition_type == "AND":
|
||||
and_listeners[listener_name] = set(trigger_methods)
|
||||
|
||||
# Breadth-first traversal to assign levels
|
||||
while queue:
|
||||
current = queue.popleft()
|
||||
current_level = levels[current]
|
||||
|
||||
@@ -74,10 +74,8 @@ def clear_all_global_hooks() -> dict[str, tuple[int, int]]:
|
||||
|
||||
|
||||
__all__ = [
|
||||
# Context classes
|
||||
"LLMCallHookContext",
|
||||
"ToolCallHookContext",
|
||||
# Decorators
|
||||
"after_llm_call",
|
||||
"after_tool_call",
|
||||
"before_llm_call",
|
||||
@@ -87,19 +85,15 @@ __all__ = [
|
||||
"clear_all_global_hooks",
|
||||
"clear_all_llm_call_hooks",
|
||||
"clear_all_tool_call_hooks",
|
||||
# Clear hooks
|
||||
"clear_before_llm_call_hooks",
|
||||
"clear_before_tool_call_hooks",
|
||||
"get_after_llm_call_hooks",
|
||||
"get_after_tool_call_hooks",
|
||||
# Get hooks
|
||||
"get_before_llm_call_hooks",
|
||||
"get_before_tool_call_hooks",
|
||||
"register_after_llm_call_hook",
|
||||
"register_after_tool_call_hook",
|
||||
# LLM Hook registration
|
||||
"register_before_llm_call_hook",
|
||||
# Tool Hook registration
|
||||
"register_before_tool_call_hook",
|
||||
"unregister_after_llm_call_hook",
|
||||
"unregister_after_tool_call_hook",
|
||||
|
||||
@@ -79,18 +79,15 @@ class LLMCallHookContext:
|
||||
crew: Optional crew reference (for direct LLM calls when executor is None)
|
||||
"""
|
||||
if executor is not None:
|
||||
# Existing path: extract from executor
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
# Handle CrewAgentExecutor vs LiteAgent differences
|
||||
if hasattr(executor, "agent"):
|
||||
self.agent = executor.agent
|
||||
self.task = cast("CrewAgentExecutor", executor).task
|
||||
self.crew = cast("CrewAgentExecutor", executor).crew
|
||||
else:
|
||||
# LiteAgent case - is the agent itself, doesn't have task/crew
|
||||
self.agent = (
|
||||
executor.original_agent
|
||||
if hasattr(executor, "original_agent")
|
||||
@@ -99,7 +96,6 @@ class LLMCallHookContext:
|
||||
self.task = None
|
||||
self.crew = None
|
||||
else:
|
||||
# New path: direct LLM call with explicit parameters
|
||||
self.executor = None
|
||||
self.messages = messages or []
|
||||
self.llm = llm
|
||||
|
||||
@@ -116,7 +116,6 @@ class ToolCallHookContext:
|
||||
event_listener.formatter.resume_live_updates()
|
||||
|
||||
|
||||
# Global hook registries
|
||||
_before_tool_call_hooks: list[BeforeToolCallHookType | BeforeToolCallHookCallable] = []
|
||||
_after_tool_call_hooks: list[AfterToolCallHookType | AfterToolCallHookCallable] = []
|
||||
|
||||
|
||||
@@ -71,7 +71,6 @@ class BeforeLLMCallHookMethod:
|
||||
"""
|
||||
if obj is None:
|
||||
return self
|
||||
# Return bound method
|
||||
return lambda context: self._meth(obj, context)
|
||||
|
||||
|
||||
|
||||
@@ -62,7 +62,7 @@ def _resolve_knowledge_sources(value: Any) -> Any:
|
||||
return resolved
|
||||
|
||||
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false" # removes logging from fastembed
|
||||
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
||||
|
||||
|
||||
def _serialize_embedder_spec(value: Any) -> dict[str, Any] | None:
|
||||
|
||||
@@ -31,7 +31,6 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
v is None
|
||||
and info.data.get(
|
||||
@@ -101,7 +100,6 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
if self.file_paths is None:
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
|
||||
# Convert single path to list
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
|
||||
@@ -16,7 +16,6 @@ try:
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
# Provide type stubs for when docling is not available
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
@@ -136,7 +135,6 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
else:
|
||||
raise FileNotFoundError(f"File not found: {local_path}")
|
||||
else:
|
||||
# this is an instance of Path
|
||||
processed_paths.append(path)
|
||||
return processed_paths
|
||||
|
||||
@@ -147,7 +145,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
[
|
||||
result.scheme in ("http", "https"),
|
||||
result.netloc,
|
||||
len(result.netloc.split(".")) >= 2, # Ensure domain has TLD
|
||||
len(result.netloc.split(".")) >= 2,
|
||||
]
|
||||
)
|
||||
except Exception:
|
||||
|
||||
@@ -12,8 +12,6 @@ from crewai.utilities.logger import Logger
|
||||
class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
"""A knowledge source that stores and queries Excel file content using embeddings."""
|
||||
|
||||
# override content to be a dict of file paths to sheet names to csv content
|
||||
|
||||
_logger: Logger = Logger(verbose=True)
|
||||
|
||||
source_type: Literal["excel"] = "excel"
|
||||
@@ -34,7 +32,6 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
v is None
|
||||
and info.data.get(
|
||||
@@ -59,7 +56,6 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
if self.file_paths is None:
|
||||
raise ValueError("Your source must be provided with a file_paths: []")
|
||||
|
||||
# Convert single path to list
|
||||
path_list: list[Path | str] = (
|
||||
[self.file_paths]
|
||||
if isinstance(self.file_paths, (str, Path))
|
||||
@@ -151,8 +147,6 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
Add Excel file content to the knowledge source, chunk it, compute embeddings,
|
||||
and save the embeddings.
|
||||
"""
|
||||
# Convert dictionary values to a single string if content is a dictionary
|
||||
# Updated to account for .xlsx workbooks with multiple tabs/sheets
|
||||
content_str = ""
|
||||
for value in self.content.values():
|
||||
if isinstance(value, dict):
|
||||
|
||||
@@ -416,7 +416,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
if v is None or isinstance(v, str):
|
||||
return v
|
||||
|
||||
# Check function signature
|
||||
sig = inspect.signature(v)
|
||||
if len(sig.parameters) != 1:
|
||||
raise ValueError(
|
||||
@@ -424,7 +423,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
f"but it accepts {len(sig.parameters)}"
|
||||
)
|
||||
|
||||
# Check return annotation if present
|
||||
if sig.return_annotation is not sig.empty:
|
||||
if sig.return_annotation == tuple[bool, Any]:
|
||||
return v
|
||||
@@ -493,7 +491,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
Returns:
|
||||
LiteAgentOutput: The result of the agent execution.
|
||||
"""
|
||||
# Inject memory tools once if memory is configured (mirrors Agent._prepare_kickoff)
|
||||
if self._memory is not None:
|
||||
from crewai.tools.memory_tools import create_memory_tools
|
||||
from crewai.utilities.string_utils import sanitize_tool_name
|
||||
@@ -507,7 +504,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
if memory_tools:
|
||||
self._parsed_tools = self._parsed_tools + parse_tools(memory_tools)
|
||||
|
||||
# Create agent info for event emission
|
||||
agent_info = {
|
||||
"id": self.id,
|
||||
"role": self.role,
|
||||
@@ -518,11 +514,9 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
}
|
||||
|
||||
try:
|
||||
# Reset state for this run
|
||||
self._iterations = 0
|
||||
self.tools_results = []
|
||||
|
||||
# Format messages for the LLM
|
||||
self._messages = self._format_messages(
|
||||
messages, response_format=response_format, input_files=input_files
|
||||
)
|
||||
@@ -539,7 +533,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
color="red",
|
||||
)
|
||||
handle_unknown_error(PRINTER, e, verbose=self.verbose)
|
||||
# Emit error event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionErrorEvent(
|
||||
@@ -623,7 +616,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
def _execute_core(
|
||||
self, agent_info: dict[str, Any], response_format: type[BaseModel] | None = None
|
||||
) -> LiteAgentOutput:
|
||||
# Emit event for agent execution start
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionStartedEvent(
|
||||
@@ -633,7 +625,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the agent using invoke loop
|
||||
active_response_format = response_format or self.response_format
|
||||
agent_finish = self._invoke_loop(response_model=active_response_format)
|
||||
if self._memory is not None:
|
||||
@@ -681,13 +672,11 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# Calculate token usage metrics
|
||||
if isinstance(self.llm, BaseLLM):
|
||||
usage_metrics = self.llm.get_token_usage_summary()
|
||||
else:
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
|
||||
# Create output
|
||||
raw_output = (
|
||||
agent_finish.output.model_dump_json()
|
||||
if isinstance(agent_finish.output, BaseModel)
|
||||
@@ -701,7 +690,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
messages=self._messages,
|
||||
)
|
||||
|
||||
# Process guardrail if set
|
||||
if self._guardrail is not None:
|
||||
guardrail_result = process_guardrail(
|
||||
output=output,
|
||||
@@ -734,7 +722,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
return self._execute_core(agent_info=agent_info)
|
||||
|
||||
# Apply guardrail result if available
|
||||
if guardrail_result.result is not None:
|
||||
if isinstance(guardrail_result.result, str):
|
||||
output.raw = guardrail_result.result
|
||||
@@ -747,7 +734,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
usage_metrics = self._token_process.get_summary()
|
||||
output.usage_metrics = usage_metrics.model_dump() if usage_metrics else None
|
||||
|
||||
# Emit completion event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=LiteAgentExecutionCompletedEvent(
|
||||
@@ -808,7 +794,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""
|
||||
base_prompt = ""
|
||||
if self._parsed_tools:
|
||||
# Use the prompt template for agents with tools
|
||||
base_prompt = I18N_DEFAULT.slice(
|
||||
"lite_agent_system_prompt_with_tools"
|
||||
).format(
|
||||
@@ -819,7 +804,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
tool_names=get_tool_names(self._parsed_tools),
|
||||
)
|
||||
else:
|
||||
# Use the prompt template for agents without tools
|
||||
base_prompt = I18N_DEFAULT.slice(
|
||||
"lite_agent_system_prompt_without_tools"
|
||||
).format(
|
||||
@@ -856,15 +840,12 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
system_prompt = self._get_default_system_prompt(response_format=response_format)
|
||||
|
||||
# Add system message at the beginning
|
||||
formatted_messages: list[LLMMessage] = [
|
||||
{"role": "system", "content": system_prompt}
|
||||
]
|
||||
|
||||
# Add the rest of the messages
|
||||
formatted_messages.extend(messages)
|
||||
|
||||
# Attach files to the last user message if provided
|
||||
if input_files:
|
||||
for msg in reversed(formatted_messages):
|
||||
if msg.get("role") == "user":
|
||||
@@ -885,7 +866,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
Returns:
|
||||
AgentFinish: The final result of the agent execution.
|
||||
"""
|
||||
# Execute the agent loop
|
||||
formatted_answer: AgentAction | AgentFinish | None = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
@@ -963,7 +943,6 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
|
||||
@@ -114,7 +114,6 @@ MAX_CONTEXT: Final[int] = 2097152 # Current max from gemini-1.5-pro
|
||||
ANTHROPIC_PREFIXES: Final[tuple[str, str, str]] = ("anthropic/", "claude-", "claude/")
|
||||
|
||||
LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
# openai
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
"gpt-4o-mini": 200000,
|
||||
@@ -126,7 +125,6 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"o1-mini": 128000,
|
||||
"o3-mini": 200000,
|
||||
"o4-mini": 200000,
|
||||
# gemini
|
||||
"gemini-3-pro-preview": 1048576,
|
||||
"gemini-2.0-flash": 1048576,
|
||||
"gemini-2.0-flash-thinking-exp-01-21": 32768,
|
||||
@@ -141,9 +139,7 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"gemini/gemma-3-4b-it": 128000,
|
||||
"gemini/gemma-3-12b-it": 128000,
|
||||
"gemini/gemma-3-27b-it": 128000,
|
||||
# deepseek
|
||||
"deepseek-chat": 128000,
|
||||
# groq
|
||||
"gemma2-9b-it": 8192,
|
||||
"gemma-7b-it": 8192,
|
||||
"llama3-groq-70b-8192-tool-use-preview": 8192,
|
||||
@@ -159,7 +155,6 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"mixtral-8x7b-32768": 32768,
|
||||
"llama-3.3-70b-versatile": 128000,
|
||||
"llama-3.3-70b-instruct": 128000,
|
||||
# sambanova
|
||||
"Meta-Llama-3.3-70B-Instruct": 131072,
|
||||
"QwQ-32B-Preview": 8192,
|
||||
"Qwen2.5-72B-Instruct": 8192,
|
||||
@@ -171,11 +166,9 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"Llama-3.2-11B-Vision-Instruct": 16384,
|
||||
"Meta-Llama-3.2-3B-Instruct": 4096,
|
||||
"Meta-Llama-3.2-1B-Instruct": 16384,
|
||||
# bedrock
|
||||
"us.amazon.nova-pro-v1:0": 300000,
|
||||
"us.amazon.nova-micro-v1:0": 128000,
|
||||
"us.amazon.nova-lite-v1:0": 300000,
|
||||
# Claude 4 models
|
||||
"us.anthropic.claude-opus-4-7": 1000000,
|
||||
"us.anthropic.claude-sonnet-4-6": 1000000,
|
||||
"us.anthropic.claude-opus-4-6-v1": 1000000,
|
||||
@@ -203,7 +196,6 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"eu.anthropic.claude-3-5-sonnet-20240620-v1:0": 200000,
|
||||
"eu.anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"eu.anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
# Claude 4 EU
|
||||
"eu.anthropic.claude-opus-4-7": 1000000,
|
||||
"eu.anthropic.claude-sonnet-4-6": 1000000,
|
||||
"eu.anthropic.claude-opus-4-6-v1": 1000000,
|
||||
@@ -219,7 +211,6 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"apac.anthropic.claude-3-5-sonnet-20241022-v2:0": 200000,
|
||||
"apac.anthropic.claude-3-sonnet-20240229-v1:0": 200000,
|
||||
"apac.anthropic.claude-3-haiku-20240307-v1:0": 200000,
|
||||
# Claude 4 APAC
|
||||
"apac.anthropic.claude-opus-4-7": 1000000,
|
||||
"apac.anthropic.claude-sonnet-4-6": 1000000,
|
||||
"apac.anthropic.claude-opus-4-6-v1": 1000000,
|
||||
@@ -264,7 +255,6 @@ LLM_CONTEXT_WINDOW_SIZES: Final[dict[str, int]] = {
|
||||
"ai21.jamba-instruct-v1:0": 256000,
|
||||
"mistral.mistral-7b-instruct-v0:2": 32000,
|
||||
"mistral.mixtral-8x7b-instruct-v0:1": 32000,
|
||||
# mistral
|
||||
"mistral-tiny": 32768,
|
||||
"mistral-small-latest": 32768,
|
||||
"mistral-medium-latest": 32768,
|
||||
@@ -291,7 +281,6 @@ SUPPORTED_NATIVE_PROVIDERS: Final[list[str]] = [
|
||||
"gemini",
|
||||
"bedrock",
|
||||
"aws",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
"ollama",
|
||||
@@ -380,7 +369,6 @@ class LLM(BaseLLM):
|
||||
"gemini": "gemini",
|
||||
"bedrock": "bedrock",
|
||||
"aws": "bedrock",
|
||||
# OpenAI-compatible providers
|
||||
"openrouter": "openrouter",
|
||||
"deepseek": "deepseek",
|
||||
"ollama": "ollama",
|
||||
@@ -421,7 +409,6 @@ class LLM(BaseLLM):
|
||||
except Exception as e:
|
||||
raise ImportError(f"Error importing native provider: {e}") from e
|
||||
|
||||
# FALLBACK to LiteLLM
|
||||
if not LITELLM_AVAILABLE:
|
||||
native_list = ", ".join(SUPPORTED_NATIVE_PROVIDERS)
|
||||
error_msg = (
|
||||
@@ -542,7 +529,6 @@ class LLM(BaseLLM):
|
||||
# azure does not provide a list of available models, determine a better way to handle this
|
||||
return True
|
||||
|
||||
# Fallback to pattern matching for models not in constants
|
||||
return cls._matches_provider_pattern(model, provider)
|
||||
|
||||
@classmethod
|
||||
@@ -606,7 +592,6 @@ class LLM(BaseLLM):
|
||||
|
||||
return BedrockCompletion
|
||||
|
||||
# OpenAI-compatible providers
|
||||
openai_compatible_providers = {
|
||||
"openrouter",
|
||||
"deepseek",
|
||||
@@ -672,15 +657,12 @@ class LLM(BaseLLM):
|
||||
Returns:
|
||||
Dict[str, Any]: Parameters for the completion call
|
||||
"""
|
||||
# --- 1) Format messages according to provider requirements
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
# --- 1a) Process any file attachments into multimodal content
|
||||
if not skip_file_processing:
|
||||
messages = self._process_message_files(messages)
|
||||
formatted_messages = self._format_messages_for_provider(messages)
|
||||
|
||||
# --- 2) Prepare the parameters for the completion call
|
||||
params = {
|
||||
"model": self.model,
|
||||
"messages": formatted_messages,
|
||||
@@ -709,7 +691,6 @@ class LLM(BaseLLM):
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
# Remove None values from params
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
def _handle_streaming_response(
|
||||
@@ -737,7 +718,6 @@ class LLM(BaseLLM):
|
||||
Raises:
|
||||
Exception: If no content is received from the streaming response
|
||||
"""
|
||||
# --- 1) Initialize response tracking
|
||||
full_response = ""
|
||||
last_chunk = None
|
||||
chunk_count = 0
|
||||
@@ -747,33 +727,27 @@ class LLM(BaseLLM):
|
||||
AccumulatedToolArgs
|
||||
)
|
||||
|
||||
# --- 2) Make sure stream is set to True and include usage metrics
|
||||
params["stream"] = True
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
try:
|
||||
# --- 3) Process each chunk in the stream
|
||||
for chunk in litellm.completion(**params):
|
||||
chunk_count += 1
|
||||
last_chunk = chunk
|
||||
|
||||
# Extract content from the chunk
|
||||
chunk_content = None
|
||||
response_id = None
|
||||
|
||||
if isinstance(chunk, ModelResponseBase):
|
||||
response_id = chunk.id
|
||||
|
||||
# Safely extract content from various chunk formats
|
||||
try:
|
||||
# Try to access choices safely
|
||||
choices = None
|
||||
if isinstance(chunk, dict) and "choices" in chunk:
|
||||
choices = chunk["choices"]
|
||||
elif isinstance(chunk, ModelResponseStream):
|
||||
choices = chunk.choices
|
||||
|
||||
# Try to extract usage information if available
|
||||
# NOTE: usage is a pydantic extra field on ModelResponseBase,
|
||||
# so it must be accessed via model_extra.
|
||||
if isinstance(chunk, dict) and "usage" in chunk:
|
||||
@@ -784,29 +758,23 @@ class LLM(BaseLLM):
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
# Handle different delta formats
|
||||
delta = None
|
||||
if isinstance(choice, dict) and "delta" in choice:
|
||||
delta = choice["delta"]
|
||||
elif isinstance(choice, LiteLLMStreamingChoices):
|
||||
delta = choice.delta
|
||||
|
||||
# Extract content from delta
|
||||
if delta:
|
||||
# Handle dict format
|
||||
if isinstance(delta, dict):
|
||||
if "content" in delta and delta["content"] is not None:
|
||||
chunk_content = delta["content"]
|
||||
# Handle object format
|
||||
elif isinstance(delta, LiteLLMDelta):
|
||||
chunk_content = delta.content
|
||||
|
||||
# Handle case where content might be None or empty
|
||||
if chunk_content is None and isinstance(delta, dict):
|
||||
# Some models might send empty content chunks
|
||||
chunk_content = ""
|
||||
|
||||
# Enable tool calls using streaming
|
||||
if "tool_calls" in delta:
|
||||
tool_calls = delta["tool_calls"]
|
||||
if tool_calls:
|
||||
@@ -826,9 +794,7 @@ class LLM(BaseLLM):
|
||||
logging.debug(f"Error extracting content from chunk: {e}")
|
||||
logging.debug(f"Chunk format: {type(chunk)}, content: {chunk}")
|
||||
|
||||
# Only add non-None content to the response
|
||||
if chunk_content is not None:
|
||||
# Add the chunk content to the full response
|
||||
full_response += chunk_content
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -842,16 +808,13 @@ class LLM(BaseLLM):
|
||||
call_id=get_current_call_id(),
|
||||
),
|
||||
)
|
||||
# --- 4) Fallback to non-streaming if no content received
|
||||
if not full_response.strip() and chunk_count == 0:
|
||||
logging.warning(
|
||||
"No chunks received in streaming response, falling back to non-streaming"
|
||||
)
|
||||
non_streaming_params = params.copy()
|
||||
non_streaming_params["stream"] = False
|
||||
non_streaming_params.pop(
|
||||
"stream_options", None
|
||||
) # Remove stream_options for non-streaming call
|
||||
non_streaming_params.pop("stream_options", None)
|
||||
return self._handle_non_streaming_response(
|
||||
non_streaming_params,
|
||||
callbacks,
|
||||
@@ -860,14 +823,12 @@ class LLM(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# --- 5) Handle empty response with chunks
|
||||
if not full_response.strip() and chunk_count > 0:
|
||||
logging.warning(
|
||||
f"Received {chunk_count} chunks but no content was extracted"
|
||||
)
|
||||
if last_chunk is not None:
|
||||
try:
|
||||
# Try to extract content from the last chunk's message
|
||||
choices = None
|
||||
if isinstance(last_chunk, dict) and "choices" in last_chunk:
|
||||
choices = last_chunk["choices"]
|
||||
@@ -877,7 +838,6 @@ class LLM(BaseLLM):
|
||||
if choices and len(choices) > 0:
|
||||
choice = choices[0]
|
||||
|
||||
# Try to get content from message
|
||||
message = None
|
||||
if isinstance(choice, dict) and "message" in choice:
|
||||
message = choice["message"]
|
||||
@@ -902,13 +862,11 @@ class LLM(BaseLLM):
|
||||
f"Last chunk format: {type(last_chunk)}, content: {last_chunk}"
|
||||
)
|
||||
|
||||
# --- 6) If still empty, raise an error instead of using a default response
|
||||
if not full_response.strip() and len(accumulated_tool_args) == 0:
|
||||
raise Exception(
|
||||
"No content received from streaming response. Received empty chunks or failed to extract content."
|
||||
)
|
||||
|
||||
# --- 7) Check for tool calls in the final response
|
||||
tool_calls = None
|
||||
try:
|
||||
if last_chunk:
|
||||
@@ -935,7 +893,6 @@ class LLM(BaseLLM):
|
||||
except Exception as e:
|
||||
logging.debug(f"Error checking for tool calls: {e}")
|
||||
|
||||
# Track token usage and log callbacks if available in streaming mode
|
||||
if usage_info:
|
||||
self._track_token_usage_internal(usage_info)
|
||||
self._handle_streaming_callbacks(callbacks, usage_info, last_chunk)
|
||||
@@ -986,12 +943,10 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return full_response
|
||||
|
||||
# --- 9) Handle tool calls if present
|
||||
tool_result = self._handle_tool_call(tool_calls, available_functions)
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 10) Emit completion event and return response
|
||||
usage_dict = self._usage_to_dict(usage_info)
|
||||
self._handle_emit_call_events(
|
||||
response=full_response,
|
||||
@@ -1004,10 +959,8 @@ class LLM(BaseLLM):
|
||||
return full_response
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
raise
|
||||
except Exception as e:
|
||||
# Check if this is a context window error and convert to our exception type
|
||||
error_msg = str(e)
|
||||
if LLMContextLengthExceededError._is_context_limit_error(error_msg):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
@@ -1101,9 +1054,7 @@ class LLM(BaseLLM):
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
# Use the usage_info we've been tracking
|
||||
if not usage_info:
|
||||
# Try to get usage from the last chunk if we haven't already
|
||||
try:
|
||||
if last_chunk:
|
||||
if (
|
||||
@@ -1152,7 +1103,6 @@ class LLM(BaseLLM):
|
||||
Returns:
|
||||
str: The response text
|
||||
"""
|
||||
# --- 1) Handle response_model with InternalInstructor for LiteLLM
|
||||
if response_model and self.is_litellm:
|
||||
from crewai.utilities.internal_instructor import InternalInstructor
|
||||
|
||||
@@ -1160,7 +1110,6 @@ class LLM(BaseLLM):
|
||||
if not messages:
|
||||
raise ValueError("Messages are required when using response_model")
|
||||
|
||||
# Combine all message content for InternalInstructor
|
||||
combined_content = "\n\n".join(
|
||||
f"{msg['role'].upper()}: {msg['content']}" for msg in messages
|
||||
)
|
||||
@@ -1197,10 +1146,8 @@ class LLM(BaseLLM):
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
raise
|
||||
except Exception as e:
|
||||
# Check if this is a context window error and convert to our exception type
|
||||
error_msg = str(e)
|
||||
if LLMContextLengthExceededError._is_context_limit_error(error_msg):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
@@ -1212,7 +1159,6 @@ class LLM(BaseLLM):
|
||||
else None
|
||||
)
|
||||
|
||||
# --- 2) Handle structured output response (when response_model is provided)
|
||||
if response_model is not None:
|
||||
# When using instructor/response_model, litellm returns a Pydantic model instance
|
||||
if isinstance(response, BaseModel):
|
||||
@@ -1227,12 +1173,10 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return structured_response
|
||||
|
||||
# --- 3) Extract response message and content (standard response)
|
||||
response_message = cast(Choices, cast(ModelResponse, response).choices)[
|
||||
0
|
||||
].message
|
||||
text_response = response_message.content or ""
|
||||
# --- 3) Handle callbacks with usage info
|
||||
if callbacks and len(callbacks) > 0:
|
||||
for callback in callbacks:
|
||||
if isinstance(callback, TokenCalcHandler):
|
||||
@@ -1249,14 +1193,11 @@ class LLM(BaseLLM):
|
||||
start_time=0,
|
||||
end_time=0,
|
||||
)
|
||||
# --- 4) Check for tool calls
|
||||
tool_calls = response_message.tool_calls or []
|
||||
|
||||
# --- 5) If there are tool calls but no available functions, return the tool calls
|
||||
if tool_calls and not available_functions:
|
||||
return tool_calls
|
||||
|
||||
# --- 6) If there are no tool calls to execute, return the text response directly
|
||||
if not tool_calls and text_response:
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
@@ -1268,7 +1209,6 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
# --- 7) Handle tool calls if present (execute when available_functions provided)
|
||||
if tool_calls and available_functions:
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
@@ -1276,7 +1216,6 @@ class LLM(BaseLLM):
|
||||
if tool_result is not None:
|
||||
return tool_result
|
||||
|
||||
# --- 8) If tool call handling didn't return a result, emit completion event and return text response
|
||||
self._handle_emit_call_events(
|
||||
response=text_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -1348,10 +1287,8 @@ class LLM(BaseLLM):
|
||||
self._track_token_usage_internal(usage_info)
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
raise
|
||||
except Exception as e:
|
||||
# Check if this is a context window error and convert to our exception type
|
||||
error_msg = str(e)
|
||||
if LLMContextLengthExceededError._is_context_limit_error(error_msg):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
@@ -1414,7 +1351,6 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return text_response
|
||||
|
||||
# Handle tool calls if present (execute when available_functions provided)
|
||||
if tool_calls and available_functions:
|
||||
tool_result = self._handle_tool_call(
|
||||
tool_calls, available_functions, from_task, from_agent
|
||||
@@ -1590,10 +1526,8 @@ class LLM(BaseLLM):
|
||||
return full_response
|
||||
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise our own context length error
|
||||
raise
|
||||
except Exception as e:
|
||||
# Check if this is a context window error and convert to our exception type
|
||||
error_msg = str(e)
|
||||
if LLMContextLengthExceededError._is_context_limit_error(error_msg):
|
||||
raise LLMContextLengthExceededError(error_msg) from e
|
||||
@@ -1630,19 +1564,15 @@ class LLM(BaseLLM):
|
||||
Returns:
|
||||
The result of the tool call, or None if no tool call was made
|
||||
"""
|
||||
# --- 1) Validate tool calls and available functions
|
||||
if not tool_calls or not available_functions:
|
||||
return None
|
||||
|
||||
# --- 2) Extract function name from first tool call
|
||||
tool_call = tool_calls[0]
|
||||
function_name = sanitize_tool_name(tool_call.function.name)
|
||||
function_args = {} # Initialize to empty dict to avoid unbound variable
|
||||
function_args = {}
|
||||
|
||||
# --- 3) Check if function is available
|
||||
if function_name in available_functions:
|
||||
try:
|
||||
# --- 3.1) Parse function arguments
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
fn = available_functions[function_name]
|
||||
|
||||
@@ -1671,7 +1601,6 @@ class LLM(BaseLLM):
|
||||
),
|
||||
)
|
||||
|
||||
# --- 3.3) Emit success event
|
||||
self._handle_emit_call_events(
|
||||
response=result,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
@@ -1680,10 +1609,7 @@ class LLM(BaseLLM):
|
||||
)
|
||||
return result
|
||||
except Exception as e:
|
||||
# --- 3.4) Handle execution errors
|
||||
fn = available_functions.get(
|
||||
function_name, lambda: None
|
||||
) # Ensure fn is always a callable
|
||||
fn = available_functions.get(function_name, lambda: None)
|
||||
logging.error(f"Error executing function '{function_name}': {e}")
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -1757,13 +1683,10 @@ class LLM(BaseLLM):
|
||||
),
|
||||
)
|
||||
|
||||
# --- 2) Validate parameters before proceeding with the call
|
||||
self._validate_call_params()
|
||||
|
||||
# --- 3) Convert string messages to proper format if needed
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
# --- 4) Handle O1 model special case (system messages not supported)
|
||||
if "o1" in self.model.lower():
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
@@ -1773,14 +1696,11 @@ class LLM(BaseLLM):
|
||||
if not self._invoke_before_llm_call_hooks(messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
self.set_callbacks(callbacks)
|
||||
try:
|
||||
# --- 6) Prepare parameters for the completion call
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
@@ -1912,7 +1832,6 @@ class LLM(BaseLLM):
|
||||
if isinstance(messages, str):
|
||||
messages = [{"role": "user", "content": messages}]
|
||||
|
||||
# Process file attachments asynchronously before preparing params
|
||||
messages = await self._aprocess_message_files(messages)
|
||||
|
||||
if "o1" in self.model.lower():
|
||||
@@ -2159,18 +2078,15 @@ class LLM(BaseLLM):
|
||||
if messages is None:
|
||||
raise TypeError("Messages cannot be None")
|
||||
|
||||
# Validate message format first
|
||||
for msg in messages:
|
||||
if not isinstance(msg, dict) or "role" not in msg or "content" not in msg:
|
||||
raise TypeError(
|
||||
"Invalid message format. Each message must be a dict with 'role' and 'content' keys"
|
||||
)
|
||||
|
||||
# Handle O1 models specially
|
||||
if "o1" in self.model.lower():
|
||||
formatted_messages = []
|
||||
for msg in messages:
|
||||
# Convert system messages to assistant messages
|
||||
if msg["role"] == "system":
|
||||
formatted_messages.append(
|
||||
{"role": "assistant", "content": msg["content"]}
|
||||
@@ -2181,7 +2097,6 @@ class LLM(BaseLLM):
|
||||
|
||||
# Handle Mistral models - they require the last message to have a role of 'user' or 'tool'
|
||||
if "mistral" in self.model.lower():
|
||||
# Check if the last message has a role of 'assistant'
|
||||
if messages and messages[-1]["role"] == "assistant":
|
||||
return [*messages, {"role": "user", "content": "Please continue."}] # type: ignore[list-item]
|
||||
return messages # type: ignore[return-value]
|
||||
@@ -2195,13 +2110,11 @@ class LLM(BaseLLM):
|
||||
):
|
||||
return [*messages, {"role": "user", "content": ""}] # type: ignore[list-item]
|
||||
|
||||
# Handle Anthropic models
|
||||
if not self.is_anthropic:
|
||||
return messages # type: ignore[return-value]
|
||||
|
||||
# Anthropic requires messages to start with 'user' role
|
||||
if not messages or messages[0]["role"] == "system":
|
||||
# If first message is system or empty, add a placeholder user message
|
||||
return [{"role": "user", "content": "."}, *messages] # type: ignore[list-item]
|
||||
|
||||
return messages # type: ignore[return-value]
|
||||
@@ -2230,7 +2143,6 @@ class LLM(BaseLLM):
|
||||
Native providers have their own validation.
|
||||
"""
|
||||
if not LITELLM_AVAILABLE or supports_response_schema is None:
|
||||
# When litellm is not available, skip validation
|
||||
# (this path should only be reached for litellm fallback models)
|
||||
return
|
||||
|
||||
@@ -2299,7 +2211,6 @@ class LLM(BaseLLM):
|
||||
min_context = 1024
|
||||
max_context = 2097152 # Current max from gemini-1.5-pro
|
||||
|
||||
# Validate all context window sizes
|
||||
for key, value in LLM_CONTEXT_WINDOW_SIZES.items():
|
||||
if value < min_context or value > max_context:
|
||||
raise ValueError(
|
||||
@@ -2324,7 +2235,6 @@ class LLM(BaseLLM):
|
||||
don't use litellm callbacks - they emit events via base_llm.py.
|
||||
"""
|
||||
if not LITELLM_AVAILABLE:
|
||||
# When litellm is not available, callbacks are still stored
|
||||
# but not registered with litellm globals
|
||||
return
|
||||
|
||||
@@ -2363,7 +2273,6 @@ class LLM(BaseLLM):
|
||||
`litellm.failure_callback` to ["langfuse"].
|
||||
"""
|
||||
if not LITELLM_AVAILABLE:
|
||||
# When litellm is not available, env callbacks have no effect
|
||||
return
|
||||
|
||||
with suppress_warnings():
|
||||
@@ -2417,7 +2326,6 @@ class LLM(BaseLLM):
|
||||
]
|
||||
}
|
||||
|
||||
# Create a new instance with the same parameters
|
||||
return LLM(
|
||||
model=self.model,
|
||||
is_litellm=self.is_litellm,
|
||||
@@ -2481,7 +2389,6 @@ class LLM(BaseLLM):
|
||||
]
|
||||
}
|
||||
|
||||
# Create a new instance with the same parameters
|
||||
return LLM(
|
||||
model=self.model,
|
||||
is_litellm=self.is_litellm,
|
||||
@@ -2524,45 +2431,33 @@ class LLM(BaseLLM):
|
||||
True if the model likely supports images.
|
||||
"""
|
||||
vision_prefixes = (
|
||||
# OpenAI — GPT-4 vision models
|
||||
"gpt-4o",
|
||||
"gpt-4-turbo",
|
||||
"gpt-4-vision",
|
||||
"gpt-4.1",
|
||||
# OpenAI — GPT-5 family (all variants support multimodal)
|
||||
"gpt-5",
|
||||
# OpenAI — o-series reasoning models with vision
|
||||
# o1, o3, o4, o4-mini support multimodal
|
||||
# o1-mini, o1-preview, o3-mini are text-only — handled via exclusion below
|
||||
"o1",
|
||||
"o3",
|
||||
"o4-mini",
|
||||
"o4",
|
||||
# Anthropic — Claude 3+ models support vision
|
||||
"claude-3",
|
||||
"claude-4",
|
||||
"claude-sonnet-4",
|
||||
"claude-opus-4",
|
||||
"claude-haiku-4",
|
||||
# Google — all Gemini models support multimodal
|
||||
"gemini",
|
||||
# xAI — Grok models support vision
|
||||
"grok",
|
||||
# Mistral — Pixtral vision model
|
||||
"pixtral",
|
||||
# Open-source vision models
|
||||
"llava",
|
||||
# Alibaba — Qwen vision-language models
|
||||
"qwen-vl",
|
||||
"qwen2-vl",
|
||||
"qwen3-vl",
|
||||
)
|
||||
# Text-only models that would otherwise match vision prefixes
|
||||
text_only_models = ("o3-mini", "o1-mini", "o1-preview")
|
||||
|
||||
model_lower = self.model.lower()
|
||||
|
||||
# Check exclusion first
|
||||
if any(
|
||||
model_lower.startswith(m) or f"/{m}" in model_lower
|
||||
for m in text_only_models
|
||||
|
||||
@@ -227,7 +227,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
if not data.get("model"):
|
||||
raise ValueError("Model name is required and cannot be empty")
|
||||
|
||||
# Normalize stop: accept str, list, or None; also accept stop_sequences alias
|
||||
stop_seqs = data.pop("stop_sequences", None)
|
||||
stop = stop_seqs if stop_seqs is not None else data.get("stop")
|
||||
if stop is None:
|
||||
@@ -239,11 +238,9 @@ class BaseLLM(BaseModel, ABC):
|
||||
else:
|
||||
data["stop"] = list(stop)
|
||||
|
||||
# Default provider
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
|
||||
# Collect unknown kwargs into additional_params
|
||||
known_fields = set(cls.model_fields.keys())
|
||||
extras = {k: v for k, v in data.items() if k not in known_fields}
|
||||
for k in extras:
|
||||
@@ -417,7 +414,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
earliest_stop_pos = stop_pos
|
||||
found_stop_word = stop_word
|
||||
|
||||
# Truncate at the stop word if found
|
||||
if found_stop_word is not None:
|
||||
truncated = content[:earliest_stop_pos].strip()
|
||||
logging.debug(
|
||||
@@ -433,7 +429,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
Returns:
|
||||
The number of tokens/characters the model can handle.
|
||||
"""
|
||||
# Default implementation - subclasses should override with model-specific values
|
||||
return DEFAULT_CONTEXT_WINDOW_SIZE
|
||||
|
||||
def supports_multimodal(self) -> bool:
|
||||
@@ -469,8 +464,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
"""
|
||||
return None
|
||||
|
||||
# Common helper methods for native SDK implementations
|
||||
|
||||
def _emit_call_started_event(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
@@ -626,7 +619,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
return None
|
||||
|
||||
try:
|
||||
# Emit tool usage started event
|
||||
started_at = datetime.now()
|
||||
|
||||
crewai_event_bus.emit(
|
||||
@@ -639,11 +631,9 @@ class BaseLLM(BaseModel, ABC):
|
||||
),
|
||||
)
|
||||
|
||||
# Execute the function
|
||||
fn = available_functions[function_name]
|
||||
result = fn(**function_args)
|
||||
|
||||
# Emit tool usage finished event
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=ToolUsageFinishedEvent(
|
||||
@@ -657,7 +647,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
),
|
||||
)
|
||||
|
||||
# Emit LLM call completed event for tool call
|
||||
self._emit_call_completed_event(
|
||||
response=result,
|
||||
call_type=LLMCallType.TOOL_CALL,
|
||||
@@ -671,7 +660,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
error_msg = f"Error executing function '{function_name}': {e!s}"
|
||||
logging.error(error_msg)
|
||||
|
||||
# Emit tool usage error event
|
||||
if not hasattr(crewai_event_bus, "emit"):
|
||||
raise ValueError(
|
||||
"crewai_event_bus does not have an emit method"
|
||||
@@ -688,7 +676,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
),
|
||||
)
|
||||
|
||||
# Emit LLM call failed event
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg,
|
||||
from_task=from_task,
|
||||
@@ -808,7 +795,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
return response
|
||||
|
||||
try:
|
||||
# Try to parse as JSON first
|
||||
if response.strip().startswith("{") or response.strip().startswith("["):
|
||||
data = json.loads(response)
|
||||
return response_format.model_validate(data)
|
||||
@@ -846,7 +832,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
Args:
|
||||
usage_data: Token usage data from the API response
|
||||
"""
|
||||
# Extract tokens in a provider-agnostic way
|
||||
prompt_tokens = (
|
||||
usage_data.get("prompt_tokens")
|
||||
or usage_data.get("prompt_token_count")
|
||||
@@ -915,7 +900,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
... ):
|
||||
... raise ValueError("LLM call blocked by hook")
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None:
|
||||
return True
|
||||
|
||||
@@ -985,7 +969,6 @@ class BaseLLM(BaseModel, ABC):
|
||||
... messages, result, from_agent
|
||||
... )
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None or not isinstance(response, str):
|
||||
return response
|
||||
|
||||
|
||||
@@ -299,7 +299,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
"""
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
@@ -309,7 +308,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Anthropic
|
||||
formatted_messages, system_message = (
|
||||
self._format_messages_for_anthropic(messages)
|
||||
)
|
||||
@@ -319,14 +317,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools, available_functions
|
||||
)
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
@@ -448,11 +444,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
params["system"] = system_message
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
@@ -460,7 +454,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if self.stop_sequences:
|
||||
params["stop_sequences"] = self.stop_sequences
|
||||
|
||||
# Handle tools for Claude 3+
|
||||
if tools and self.supports_tools:
|
||||
converted_tools = self._convert_tools_for_interference(tools)
|
||||
|
||||
@@ -498,7 +491,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
anthropic_tools = []
|
||||
|
||||
for tool in tools:
|
||||
# Pass through tool search tool definitions unchanged
|
||||
tool_type = tool.get("type", "")
|
||||
if tool_type in TOOL_SEARCH_TOOL_TYPES:
|
||||
anthropic_tools.append(tool)
|
||||
@@ -560,7 +552,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if self.tool_search is None:
|
||||
return tools
|
||||
|
||||
# Check if a tool search tool is already present (user passed one manually)
|
||||
has_search_tool = any(
|
||||
t.get("type", "") in TOOL_SEARCH_TOOL_TYPES for t in tools
|
||||
)
|
||||
@@ -568,23 +559,19 @@ class AnthropicCompletion(BaseLLM):
|
||||
result: list[dict[str, Any]] = []
|
||||
|
||||
if not has_search_tool:
|
||||
# Map config type to API type identifier
|
||||
type_map = {
|
||||
"regex": "tool_search_tool_regex_20251119",
|
||||
"bm25": "tool_search_tool_bm25_20251119",
|
||||
}
|
||||
tool_type = type_map[self.tool_search.type]
|
||||
# Tool search tool names follow the convention: tool_search_tool_{variant}
|
||||
tool_name = f"tool_search_tool_{self.tool_search.type}"
|
||||
result.append({"type": tool_type, "name": tool_name})
|
||||
|
||||
for tool in tools:
|
||||
# Don't modify tool search tools
|
||||
if tool.get("type", "") in TOOL_SEARCH_TOOL_TYPES:
|
||||
result.append(tool)
|
||||
continue
|
||||
|
||||
# Mark regular tools as deferred if not already set
|
||||
if "defer_loading" not in tool:
|
||||
tool = {**tool, "defer_loading": True}
|
||||
result.append(tool)
|
||||
@@ -724,7 +711,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if len(text_blocks) == 1 and isinstance(text_blocks[0], str):
|
||||
cache_match_contents.append(text_blocks[0])
|
||||
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
formatted_messages: list[LLMMessage] = []
|
||||
@@ -752,14 +738,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
}
|
||||
pending_tool_results.append(tool_result)
|
||||
elif role == "assistant":
|
||||
# First, flush any pending tool results as a user message
|
||||
if pending_tool_results:
|
||||
formatted_messages.append(
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
)
|
||||
pending_tool_results = []
|
||||
|
||||
# Handle assistant message with tool_calls (convert to Anthropic format)
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
if tool_calls:
|
||||
assistant_content: list[dict[str, Any]] = []
|
||||
@@ -798,7 +782,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
LLMMessage(role="assistant", content=content_str)
|
||||
)
|
||||
else:
|
||||
# User message - first flush any pending tool results
|
||||
if pending_tool_results:
|
||||
formatted_messages.append(
|
||||
{"role": "user", "content": pending_tool_results}
|
||||
@@ -819,16 +802,13 @@ class AnthropicCompletion(BaseLLM):
|
||||
LLMMessage(role=role_str, content=content_str)
|
||||
)
|
||||
|
||||
# Flush any remaining pending tool results
|
||||
if pending_tool_results:
|
||||
formatted_messages.append({"role": "user", "content": pending_tool_results})
|
||||
|
||||
# Ensure first message is from user (Anthropic requirement)
|
||||
# Anthropic requires the first message to come from "user"
|
||||
if not formatted_messages:
|
||||
# If no messages, add a default user message
|
||||
formatted_messages.append({"role": "user", "content": "Hello"})
|
||||
elif formatted_messages[0]["role"] != "user":
|
||||
# If first message is not from user, insert a user message at the beginning
|
||||
formatted_messages.insert(0, {"role": "user", "content": "Hello"})
|
||||
|
||||
# Stamp cache_control on the message(s) whose original content was
|
||||
@@ -983,9 +963,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
# This allows the executor to manage tool execution with proper
|
||||
# message history and post-tool reasoning prompts
|
||||
# Without available_functions, return tool calls so the executor can
|
||||
# manage execution with proper message history and post-tool reasoning prompts
|
||||
if not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=list(tool_uses),
|
||||
@@ -1207,7 +1186,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if not available_functions:
|
||||
return list(tool_uses)
|
||||
|
||||
# Execute first tool and return result directly
|
||||
result = self._execute_first_tool(
|
||||
tool_uses, available_functions, from_task, from_agent
|
||||
)
|
||||
@@ -1330,7 +1308,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
follow_up_params = params.copy()
|
||||
|
||||
# Add Claude's tool use response to conversation
|
||||
assistant_content: list[
|
||||
ThinkingBlock | ToolUseBlock | TextBlock | dict[str, Any]
|
||||
] = []
|
||||
@@ -1352,22 +1329,18 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
assistant_message = {"role": "assistant", "content": assistant_content}
|
||||
|
||||
# Add user message with tool results
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
|
||||
# Update messages for follow-up call
|
||||
follow_up_params["messages"] = params["messages"] + [
|
||||
assistant_message,
|
||||
user_message,
|
||||
]
|
||||
|
||||
try:
|
||||
# Send tool results back to Claude for final response
|
||||
final_response: Message = self._get_sync_client().messages.create(
|
||||
**follow_up_params
|
||||
)
|
||||
|
||||
# Track token usage for follow-up call
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
@@ -1388,7 +1361,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
# Emit completion event for the final response
|
||||
self._emit_call_completed_event(
|
||||
response=final_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -1398,7 +1370,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
usage=follow_up_usage,
|
||||
)
|
||||
|
||||
# Log combined token usage
|
||||
total_usage = {
|
||||
"input_tokens": follow_up_usage.get("input_tokens", 0),
|
||||
"output_tokens": follow_up_usage.get("output_tokens", 0),
|
||||
@@ -1416,7 +1387,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
# Fallback: return the first tool result if follow-up fails
|
||||
# Fallback to first tool result when follow-up fails
|
||||
if tool_results:
|
||||
return cast(str, tool_results[0]["content"])
|
||||
raise e
|
||||
@@ -1516,7 +1487,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
]
|
||||
|
||||
if tool_uses:
|
||||
# If no available_functions, return tool calls for executor to handle
|
||||
if not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=list(tool_uses),
|
||||
@@ -1825,7 +1795,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
|
||||
|
||||
# Context window sizes for Anthropic models
|
||||
context_windows = {
|
||||
"claude-3-5-sonnet": 200000,
|
||||
"claude-3-5-haiku": 200000,
|
||||
@@ -1838,12 +1807,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
"claude-instant": 100000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Claude models
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -90,7 +90,6 @@ class AzureCompletion(BaseLLM):
|
||||
is_azure_openai_endpoint: bool = False
|
||||
credential_scopes: list[str] | None = None
|
||||
|
||||
# Responses API settings
|
||||
api: Literal["completions", "responses"] = "completions"
|
||||
reasoning_effort: str | None = None
|
||||
instructions: str | None = None
|
||||
@@ -119,7 +118,6 @@ class AzureCompletion(BaseLLM):
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = data.get("api_key") or os.getenv("AZURE_API_KEY")
|
||||
data["endpoint"] = (
|
||||
data.get("endpoint")
|
||||
@@ -506,7 +504,6 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
@@ -518,7 +515,6 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
effective_response_model = response_model or self.response_format
|
||||
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
@@ -526,12 +522,10 @@ class AzureCompletion(BaseLLM):
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, effective_response_model
|
||||
)
|
||||
|
||||
# Handle streaming vs non-streaming
|
||||
if self.stream:
|
||||
return self._handle_streaming_completion(
|
||||
completion_params,
|
||||
@@ -663,12 +657,10 @@ class AzureCompletion(BaseLLM):
|
||||
strict=json_schema_info["strict"],
|
||||
)
|
||||
|
||||
# Only include model parameter for non-Azure OpenAI endpoints
|
||||
# Azure OpenAI endpoints have the deployment name in the URL
|
||||
# Azure OpenAI endpoints embed deployment name in URL and reject model in body
|
||||
if not self.is_azure_openai_endpoint:
|
||||
params["model"] = self.model
|
||||
|
||||
# Add optional parameters if set
|
||||
if self.temperature is not None:
|
||||
params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
@@ -683,7 +675,6 @@ class AzureCompletion(BaseLLM):
|
||||
if stops and self.supports_stop_words():
|
||||
params["stop"] = stops
|
||||
|
||||
# Handle tools/functions for Azure OpenAI models
|
||||
if tools and self.is_openai_model:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
@@ -751,14 +742,13 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
List of dict objects with 'role' and 'content' keys
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
azure_messages: list[LLMMessage] = []
|
||||
|
||||
for message in base_formatted:
|
||||
role = message.get("role", "user") # Default to user if no role
|
||||
# Handle None content - Azure requires string content
|
||||
role = message.get("role", "user")
|
||||
# Azure requires string content; coerce None to ""
|
||||
content = message.get("content") or ""
|
||||
|
||||
if role == "tool":
|
||||
@@ -772,17 +762,15 @@ class AzureCompletion(BaseLLM):
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
# Handle assistant messages with tool_calls
|
||||
elif role == "assistant" and message.get("tool_calls"):
|
||||
tool_calls = message.get("tool_calls", [])
|
||||
azure_msg: LLMMessage = {
|
||||
"role": "assistant",
|
||||
"content": content, # Already defaulted to "" above
|
||||
"content": content,
|
||||
"tool_calls": tool_calls,
|
||||
}
|
||||
azure_messages.append(azure_msg)
|
||||
else:
|
||||
# Azure AI Inference requires both 'role' and 'content'
|
||||
azure_messages.append({"role": role, "content": content})
|
||||
|
||||
return azure_messages
|
||||
@@ -857,12 +845,10 @@ class AzureCompletion(BaseLLM):
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# Extract and track token usage
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# If there are tool_calls but no available_functions, return the tool_calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
# Without available_functions, return tool_calls so the caller (executor) handles execution
|
||||
if message.tool_calls and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=list(message.tool_calls),
|
||||
@@ -874,7 +860,6 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
return list(message.tool_calls)
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
@@ -886,7 +871,6 @@ class AzureCompletion(BaseLLM):
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -898,7 +882,6 @@ class AzureCompletion(BaseLLM):
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Extract content
|
||||
content = message.content or ""
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
@@ -913,7 +896,6 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
# Emit completion event and return content
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -1059,8 +1041,7 @@ class AzureCompletion(BaseLLM):
|
||||
usage=usage_data,
|
||||
)
|
||||
|
||||
# If there are tool_calls but no available_functions, return them
|
||||
# in OpenAI-compatible format for executor to handle
|
||||
# Without available_functions, return tool calls in OpenAI-compatible format for the executor
|
||||
if tool_calls and not available_functions:
|
||||
formatted_tool_calls = [
|
||||
{
|
||||
@@ -1083,7 +1064,6 @@ class AzureCompletion(BaseLLM):
|
||||
)
|
||||
return formatted_tool_calls
|
||||
|
||||
# Handle completed tool calls
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
@@ -1094,7 +1074,6 @@ class AzureCompletion(BaseLLM):
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
continue
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -1106,10 +1085,8 @@ class AzureCompletion(BaseLLM):
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Apply stop words to full response
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
# Emit completion event and return full response
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -1237,7 +1214,6 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
# Azure OpenAI models support function calling
|
||||
return self.is_openai_model
|
||||
|
||||
def supports_stop_words(self) -> bool:
|
||||
@@ -1277,7 +1253,6 @@ class AzureCompletion(BaseLLM):
|
||||
f"Context window for {key} must be between {min_context} and {max_context}"
|
||||
)
|
||||
|
||||
# Context window sizes for common Azure models
|
||||
context_windows = {
|
||||
"gpt-4": 8192,
|
||||
"gpt-4o": 128000,
|
||||
@@ -1288,14 +1263,12 @@ class AzureCompletion(BaseLLM):
|
||||
"text-embedding": 8191,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in sorted(
|
||||
context_windows.items(), key=lambda x: len(x[0]), reverse=True
|
||||
):
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
@staticmethod
|
||||
|
||||
@@ -69,7 +69,6 @@ def _preprocess_structured_data(
|
||||
import re
|
||||
from typing import get_origin
|
||||
|
||||
# Get model field annotations
|
||||
model_fields = response_model.model_fields
|
||||
|
||||
processed_data = dict(data)
|
||||
@@ -80,17 +79,14 @@ def _preprocess_structured_data(
|
||||
|
||||
value = processed_data[field_name]
|
||||
|
||||
# Check if the field expects a list type
|
||||
annotation = field_info.annotation
|
||||
origin = get_origin(annotation)
|
||||
|
||||
# Handle list[X] or List[X] types
|
||||
is_list_type = origin is list or (
|
||||
origin is not None and str(origin).startswith("list")
|
||||
)
|
||||
|
||||
if is_list_type and isinstance(value, str):
|
||||
# Try to parse markdown-style bullet points or numbered lists
|
||||
lines = value.strip().split("\n")
|
||||
parsed_items = []
|
||||
|
||||
@@ -99,8 +95,7 @@ def _preprocess_structured_data(
|
||||
if not line:
|
||||
continue
|
||||
|
||||
# Remove common bullet point prefixes
|
||||
# Matches: "- item", "* item", "• item", "1. item", "1) item"
|
||||
# Strip common list markers: "- item", "* item", "• item", "1. item", "1) item"
|
||||
cleaned = re.sub(r"^[-*•]\s*", "", line)
|
||||
cleaned = re.sub(r"^\d+[.)]\s*", "", cleaned)
|
||||
cleaned = cleaned.strip()
|
||||
@@ -266,11 +261,9 @@ class BedrockCompletion(BaseLLM):
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Force provider to bedrock
|
||||
data.pop("provider", None)
|
||||
data["provider"] = "bedrock"
|
||||
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
@@ -279,7 +272,6 @@ class BedrockCompletion(BaseLLM):
|
||||
seqs = list(seqs)
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["aws_access_key_id"] = data.get("aws_access_key_id") or os.getenv(
|
||||
"AWS_ACCESS_KEY_ID"
|
||||
)
|
||||
@@ -372,7 +364,6 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
with llm_call_context():
|
||||
try:
|
||||
# Emit call started event
|
||||
self._emit_call_started_event(
|
||||
messages=messages,
|
||||
tools=tools,
|
||||
@@ -382,7 +373,6 @@ class BedrockCompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages
|
||||
)
|
||||
@@ -392,20 +382,17 @@ class BedrockCompletion(BaseLLM):
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare request body
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
}
|
||||
|
||||
# Add system message if present
|
||||
if system_message:
|
||||
body["system"] = cast(
|
||||
"list[SystemContentBlockTypeDef]",
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
# Bedrock requires toolConfig when messages contain toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
@@ -415,7 +402,6 @@ class BedrockCompletion(BaseLLM):
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
@@ -425,7 +411,6 @@ class BedrockCompletion(BaseLLM):
|
||||
cast(object, {"tools": tools_from_history}),
|
||||
)
|
||||
|
||||
# Add optional advanced features if configured
|
||||
if self.guardrail_config:
|
||||
guardrail_config: GuardrailConfigurationTypeDef = cast(
|
||||
"GuardrailConfigurationTypeDef",
|
||||
@@ -535,8 +520,7 @@ class BedrockCompletion(BaseLLM):
|
||||
cast(object, [{"text": system_message}]),
|
||||
)
|
||||
|
||||
# Add tool config if present or if messages contain tool content
|
||||
# Bedrock requires toolConfig when messages have toolUse/toolResult
|
||||
# Bedrock requires toolConfig when messages contain toolUse/toolResult
|
||||
if tools:
|
||||
tool_config: ToolConfigurationTypeDef = {
|
||||
"tools": cast(
|
||||
@@ -546,7 +530,6 @@ class BedrockCompletion(BaseLLM):
|
||||
}
|
||||
body["toolConfig"] = tool_config
|
||||
elif self._messages_contain_tool_content(formatted_messages):
|
||||
# Create minimal toolConfig from tool history in messages
|
||||
tools_from_history = self._extract_tools_from_message_history(
|
||||
formatted_messages
|
||||
)
|
||||
@@ -743,7 +726,6 @@ class BedrockCompletion(BaseLLM):
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
# Filter out structured_output from tool_uses returned to executor
|
||||
non_structured_output_tool_uses = [
|
||||
tu for tu in tool_uses if tu.get("name") != STRUCTURED_OUTPUT_TOOL_NAME
|
||||
]
|
||||
@@ -759,15 +741,12 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
return non_structured_output_tool_uses
|
||||
|
||||
# Process content blocks and handle tool use correctly
|
||||
text_content = ""
|
||||
|
||||
for content_block in content:
|
||||
# Handle text content
|
||||
if "text" in content_block:
|
||||
text_content += content_block["text"]
|
||||
|
||||
# Handle tool use - corrected structure according to AWS API docs
|
||||
elif "toolUse" in content_block and available_functions:
|
||||
tool_use_block = content_block["toolUse"]
|
||||
tool_use_id = tool_use_block.get("toolUseId")
|
||||
@@ -781,7 +760,6 @@ class BedrockCompletion(BaseLLM):
|
||||
f"Tool use requested: {function_name} with ID {tool_use_id}"
|
||||
)
|
||||
|
||||
# Execute the tool
|
||||
tool_result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -821,10 +799,8 @@ class BedrockCompletion(BaseLLM):
|
||||
response_model,
|
||||
)
|
||||
|
||||
# Apply stop sequences if configured
|
||||
text_content = self._apply_stop_words(text_content)
|
||||
|
||||
# Validate final response
|
||||
if not text_content or text_content.strip() == "":
|
||||
logging.warning("Extracted empty text content from Bedrock response")
|
||||
text_content = "I apologize, but I couldn't generate a proper response. Please try again."
|
||||
@@ -845,16 +821,13 @@ class BedrockCompletion(BaseLLM):
|
||||
)
|
||||
|
||||
except ClientError as e:
|
||||
# Handle all AWS ClientError exceptions as per documentation
|
||||
error_code = e.response.get("Error", {}).get("Code", "Unknown")
|
||||
error_msg = e.response.get("Error", {}).get("Message", str(e))
|
||||
|
||||
# Log the specific error for debugging
|
||||
logging.error(f"AWS Bedrock ClientError ({error_code}): {error_msg}")
|
||||
|
||||
# Handle specific error codes as documented
|
||||
if error_code == "ValidationException":
|
||||
# This is the error we're seeing with Cohere
|
||||
# Cohere returns this when conversation alternation is broken
|
||||
if "last turn" in error_msg and "user message" in error_msg:
|
||||
raise ValueError(
|
||||
f"Conversation format error: {error_msg}. Check message alternation."
|
||||
@@ -892,7 +865,6 @@ class BedrockCompletion(BaseLLM):
|
||||
logging.error(error_msg)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Catch any other unexpected errors
|
||||
error_msg = f"Unexpected error in Bedrock converse call: {e}"
|
||||
logging.error(error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
@@ -1338,7 +1310,6 @@ class BedrockCompletion(BaseLLM):
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
# Filter out structured_output from tool_uses returned to executor
|
||||
non_structured_output_tool_uses = [
|
||||
tu for tu in tool_uses if tu.get("name") != STRUCTURED_OUTPUT_TOOL_NAME
|
||||
]
|
||||
@@ -1793,7 +1764,7 @@ class BedrockCompletion(BaseLLM):
|
||||
tool_call_id = message.get("tool_call_id")
|
||||
|
||||
if role == "system":
|
||||
# Extract system message - Converse API handles it separately
|
||||
# Converse API handles system messages separately
|
||||
if system_message:
|
||||
system_message += f"\n\n{content}"
|
||||
else:
|
||||
@@ -1835,12 +1806,9 @@ class BedrockCompletion(BaseLLM):
|
||||
{"role": "assistant", "content": bedrock_content}
|
||||
)
|
||||
else:
|
||||
# Convert to Converse API format with proper content structure
|
||||
if isinstance(content, list):
|
||||
# Already formatted as multimodal content blocks
|
||||
converse_messages.append({"role": role, "content": content})
|
||||
else:
|
||||
# String content - wrap in text block
|
||||
text_content = content if content else ""
|
||||
converse_messages.append(
|
||||
{"role": role, "content": [{"text": text_content}]}
|
||||
@@ -2073,7 +2041,6 @@ class BedrockCompletion(BaseLLM):
|
||||
"""Get the context window size for the model."""
|
||||
from crewai.llm import CONTEXT_WINDOW_USAGE_RATIO
|
||||
|
||||
# Context window sizes for common Bedrock models
|
||||
context_windows = {
|
||||
"anthropic.claude-sonnet-4": 200000,
|
||||
"anthropic.claude-opus-4": 200000,
|
||||
@@ -2094,12 +2061,10 @@ class BedrockCompletion(BaseLLM):
|
||||
"deepseek.r1": 32768,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def supports_multimodal(self) -> bool:
|
||||
|
||||
@@ -73,14 +73,12 @@ class GeminiCompletion(BaseLLM):
|
||||
"Interceptors are currently supported for OpenAI and Anthropic providers only."
|
||||
)
|
||||
|
||||
# Normalize stop_sequences from stop kwarg
|
||||
popped = data.pop("stop_sequences", None)
|
||||
seqs = popped if popped is not None else (data.get("stop") or [])
|
||||
if isinstance(seqs, str):
|
||||
seqs = [seqs]
|
||||
data["stop"] = seqs
|
||||
|
||||
# Resolve env vars
|
||||
data["api_key"] = (
|
||||
data.get("api_key")
|
||||
or os.getenv("GOOGLE_API_KEY")
|
||||
@@ -96,7 +94,6 @@ class GeminiCompletion(BaseLLM):
|
||||
use_vx = os.getenv("GOOGLE_GENAI_USE_VERTEXAI", "").lower() == "true"
|
||||
data["use_vertexai"] = use_vx
|
||||
|
||||
# Model-specific settings
|
||||
model = data.get("model", "gemini-2.0-flash-001")
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
data["supports_tools"] = bool(
|
||||
@@ -189,7 +186,6 @@ class GeminiCompletion(BaseLLM):
|
||||
if self.client_params:
|
||||
client_params.update(self.client_params)
|
||||
|
||||
# Determine authentication mode based on available credentials
|
||||
has_api_key = bool(self.api_key)
|
||||
has_project = bool(self.project)
|
||||
|
||||
@@ -466,15 +462,12 @@ class GeminiCompletion(BaseLLM):
|
||||
self.tools = tools
|
||||
config_params: dict[str, Any] = {}
|
||||
|
||||
# Add system instruction if present
|
||||
if system_instruction:
|
||||
# Convert system instruction to Content format
|
||||
system_content = types.Content(
|
||||
role="user", parts=[types.Part.from_text(text=system_instruction)]
|
||||
)
|
||||
config_params["system_instruction"] = system_content
|
||||
|
||||
# Add generation config parameters
|
||||
if self.temperature is not None:
|
||||
config_params["temperature"] = self.temperature
|
||||
if self.top_p is not None:
|
||||
@@ -568,7 +561,6 @@ class GeminiCompletion(BaseLLM):
|
||||
Returns:
|
||||
Tuple of (formatted_contents, system_instruction)
|
||||
"""
|
||||
# Use base class formatting first
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
contents: list[types.Content] = []
|
||||
@@ -578,7 +570,6 @@ class GeminiCompletion(BaseLLM):
|
||||
role = message["role"]
|
||||
content = message["content"]
|
||||
|
||||
# Build parts list from content
|
||||
parts: list[types.Part] = []
|
||||
if isinstance(content, list):
|
||||
for item in content:
|
||||
@@ -601,7 +592,7 @@ class GeminiCompletion(BaseLLM):
|
||||
text_content: str = " ".join(p.text for p in parts if p.text is not None)
|
||||
|
||||
if role == "system":
|
||||
# Extract system instruction - Gemini handles it separately
|
||||
# Gemini handles system instructions separately from content
|
||||
if system_instruction:
|
||||
system_instruction += f"\n\n{text_content}"
|
||||
else:
|
||||
@@ -675,10 +666,9 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
contents.append(types.Content(role="model", parts=tool_parts))
|
||||
else:
|
||||
# Convert role for Gemini (assistant -> model)
|
||||
# Gemini uses "model" instead of "assistant"
|
||||
gemini_role = "model" if role == "assistant" else "user"
|
||||
|
||||
# Create Content object
|
||||
gemini_content = types.Content(role=gemini_role, parts=parts)
|
||||
contents.append(gemini_content)
|
||||
|
||||
@@ -749,7 +739,6 @@ class GeminiCompletion(BaseLLM):
|
||||
"""
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
# Handle structured output validation
|
||||
if response_model:
|
||||
return self._validate_and_emit_structured_output(
|
||||
content=content,
|
||||
@@ -842,12 +831,11 @@ class GeminiCompletion(BaseLLM):
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
# Collect function call parts
|
||||
function_call_parts = [
|
||||
part for part in candidate.content.parts if part.function_call
|
||||
]
|
||||
|
||||
# Check for structured_output pseudo-tool call (used when tools + response_model)
|
||||
# structured_output pseudo-tool is used when tools + response_model are both set
|
||||
if response_model and function_call_parts:
|
||||
for part in function_call_parts:
|
||||
if (
|
||||
@@ -868,7 +856,6 @@ class GeminiCompletion(BaseLLM):
|
||||
usage=usage,
|
||||
)
|
||||
|
||||
# Filter out structured_output from function calls returned to executor
|
||||
non_structured_output_parts = [
|
||||
part
|
||||
for part in function_call_parts
|
||||
@@ -878,8 +865,8 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
]
|
||||
|
||||
# If there are function calls but no available_functions,
|
||||
# return them for the executor to handle (like OpenAI/Anthropic)
|
||||
# Without available_functions, return calls so the executor handles them
|
||||
# (matches OpenAI/Anthropic behavior).
|
||||
if non_structured_output_parts and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=non_structured_output_parts,
|
||||
@@ -891,13 +878,11 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
return non_structured_output_parts
|
||||
|
||||
# Otherwise execute the tools internally
|
||||
for part in candidate.content.parts:
|
||||
if part.function_call:
|
||||
function_name = part.function_call.name
|
||||
if function_name is None:
|
||||
continue
|
||||
# Skip structured_output - it's handled above
|
||||
if function_name == STRUCTURED_OUTPUT_TOOL_NAME:
|
||||
continue
|
||||
function_args = (
|
||||
@@ -1076,21 +1061,17 @@ class GeminiCompletion(BaseLLM):
|
||||
)
|
||||
return raw_parts
|
||||
|
||||
# Handle completed function calls (excluding structured_output)
|
||||
if non_structured_output_calls and available_functions:
|
||||
for call_data in non_structured_output_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Skip if function_name is None
|
||||
if not isinstance(function_name, str):
|
||||
continue
|
||||
|
||||
# Ensure function_args is a dict
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
@@ -1313,13 +1294,11 @@ class GeminiCompletion(BaseLLM):
|
||||
"gemma-3-27b": 128000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size for Gemini models
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens
|
||||
return int(1048576 * CONTEXT_WINDOW_USAGE_RATIO) # 1M tokens default
|
||||
|
||||
@staticmethod
|
||||
def _extract_token_usage(response: GenerateContentResponse) -> dict[str, Any]:
|
||||
|
||||
@@ -247,7 +247,6 @@ class OpenAICompletion(BaseLLM):
|
||||
if not data.get("provider"):
|
||||
data["provider"] = "openai"
|
||||
data["api_key"] = data.get("api_key") or os.getenv("OPENAI_API_KEY")
|
||||
# Extract api_base from kwargs if present
|
||||
if "api_base" not in data:
|
||||
data["api_base"] = None
|
||||
model = data.get("model", "gpt-4o")
|
||||
@@ -333,7 +332,6 @@ class OpenAICompletion(BaseLLM):
|
||||
def to_config_dict(self) -> dict[str, Any]:
|
||||
"""Extend base config with OpenAI-specific fields."""
|
||||
config = super().to_config_dict()
|
||||
# Client-level params (from OpenAI SDK)
|
||||
if self.organization:
|
||||
config["organization"] = self.organization
|
||||
if self.project:
|
||||
@@ -342,7 +340,6 @@ class OpenAICompletion(BaseLLM):
|
||||
config["timeout"] = self.timeout
|
||||
if self.max_retries != 2:
|
||||
config["max_retries"] = self.max_retries
|
||||
# Completion params
|
||||
if self.top_p is not None:
|
||||
config["top_p"] = self.top_p
|
||||
if self.frequency_penalty is not None:
|
||||
@@ -665,7 +662,6 @@ class OpenAICompletion(BaseLLM):
|
||||
for message in messages:
|
||||
if message.get("role") == "system":
|
||||
content = message.get("content", "")
|
||||
# System messages should always have string content
|
||||
content_str = content if isinstance(content, str) else str(content)
|
||||
if instructions:
|
||||
instructions = f"{instructions}\n\n{content_str}"
|
||||
@@ -674,7 +670,7 @@ class OpenAICompletion(BaseLLM):
|
||||
else:
|
||||
input_messages.append(message)
|
||||
|
||||
# Prepare input with optional reasoning items for ZDR chaining
|
||||
# Prepend reasoning items for ZDR (zero-data-retention) chaining when configured
|
||||
final_input: list[Any] = []
|
||||
if self.auto_chain_reasoning and self._last_reasoning_items:
|
||||
final_input.extend(self._last_reasoning_items)
|
||||
@@ -700,7 +696,6 @@ class OpenAICompletion(BaseLLM):
|
||||
elif self.auto_chain and self._last_response_id:
|
||||
params["previous_response_id"] = self._last_response_id
|
||||
|
||||
# Handle include parameter with auto_chain_reasoning support
|
||||
include_items: list[str] = list(self.include) if self.include else []
|
||||
if self.auto_chain_reasoning:
|
||||
if "reasoning.encrypted_content" not in include_items:
|
||||
@@ -819,11 +814,9 @@ class OpenAICompletion(BaseLLM):
|
||||
try:
|
||||
response: Response = self._get_sync_client().responses.create(**params)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
self._last_response_id = response.id
|
||||
|
||||
# Track reasoning items for ZDR auto-chaining
|
||||
if self.auto_chain_reasoning:
|
||||
reasoning_items = self._extract_reasoning_items(response)
|
||||
if reasoning_items:
|
||||
@@ -832,7 +825,6 @@ class OpenAICompletion(BaseLLM):
|
||||
usage = self._extract_responses_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# If parse_tool_outputs is enabled, return structured result
|
||||
if self.parse_tool_outputs:
|
||||
parsed_result = self._extract_builtin_tool_outputs(response)
|
||||
parsed_result.text = self._apply_stop_words(parsed_result.text)
|
||||
@@ -957,11 +949,9 @@ class OpenAICompletion(BaseLLM):
|
||||
**params
|
||||
)
|
||||
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and response.id:
|
||||
self._last_response_id = response.id
|
||||
|
||||
# Track reasoning items for ZDR auto-chaining
|
||||
if self.auto_chain_reasoning:
|
||||
reasoning_items = self._extract_reasoning_items(response)
|
||||
if reasoning_items:
|
||||
@@ -970,7 +960,6 @@ class OpenAICompletion(BaseLLM):
|
||||
usage = self._extract_responses_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# If parse_tool_outputs is enabled, return structured result
|
||||
if self.parse_tool_outputs:
|
||||
parsed_result = self._extract_builtin_tool_outputs(response)
|
||||
parsed_result.text = self._apply_stop_words(parsed_result.text)
|
||||
@@ -1124,10 +1113,8 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
elif event.type == "response.completed":
|
||||
final_response = event.response
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and event.response and event.response.id:
|
||||
self._last_response_id = event.response.id
|
||||
# Track reasoning items for ZDR auto-chaining
|
||||
if self.auto_chain_reasoning and event.response:
|
||||
reasoning_items = self._extract_reasoning_items(event.response)
|
||||
if reasoning_items:
|
||||
@@ -1136,7 +1123,6 @@ class OpenAICompletion(BaseLLM):
|
||||
usage = self._extract_responses_token_usage(event.response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# If parse_tool_outputs is enabled, return structured result
|
||||
if self.parse_tool_outputs and final_response:
|
||||
parsed_result = self._extract_builtin_tool_outputs(final_response)
|
||||
parsed_result.text = self._apply_stop_words(parsed_result.text)
|
||||
@@ -1252,10 +1238,8 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
elif event.type == "response.completed":
|
||||
final_response = event.response
|
||||
# Track response ID for auto-chaining
|
||||
if self.auto_chain and event.response and event.response.id:
|
||||
self._last_response_id = event.response.id
|
||||
# Track reasoning items for ZDR auto-chaining
|
||||
if self.auto_chain_reasoning and event.response:
|
||||
reasoning_items = self._extract_reasoning_items(event.response)
|
||||
if reasoning_items:
|
||||
@@ -1264,7 +1248,6 @@ class OpenAICompletion(BaseLLM):
|
||||
usage = self._extract_responses_token_usage(event.response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
# If parse_tool_outputs is enabled, return structured result
|
||||
if self.parse_tool_outputs and final_response:
|
||||
parsed_result = self._extract_builtin_tool_outputs(final_response)
|
||||
parsed_result.text = self._apply_stop_words(parsed_result.text)
|
||||
@@ -1551,7 +1534,6 @@ class OpenAICompletion(BaseLLM):
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
params["tool_choice"] = "auto"
|
||||
|
||||
# Filter out CrewAI-specific parameters that shouldn't go to the API
|
||||
crewai_specific_params = {
|
||||
"callbacks",
|
||||
"available_functions",
|
||||
@@ -1644,8 +1626,7 @@ class OpenAICompletion(BaseLLM):
|
||||
choice: Choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# If there are tool_calls but no available_functions, return the tool_calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
# Without available_functions, return tool_calls so the caller (executor) handles execution
|
||||
if message.tool_calls and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=list(message.tool_calls),
|
||||
@@ -1657,7 +1638,6 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return list(message.tool_calls)
|
||||
|
||||
# If there are tool_calls and available_functions, execute the tools
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0]
|
||||
if not isinstance(tool_call, ChatCompletionMessageFunctionToolCall):
|
||||
@@ -1732,7 +1712,6 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
raise ConnectionError(error_msg) from e
|
||||
except Exception as e:
|
||||
# Handle context length exceeded and other errors
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
@@ -2033,8 +2012,7 @@ class OpenAICompletion(BaseLLM):
|
||||
choice: Choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# If there are tool_calls but no available_functions, return the tool_calls
|
||||
# This allows the caller (e.g., executor) to handle tool execution
|
||||
# Without available_functions, return tool_calls so the caller (executor) handles execution
|
||||
if message.tool_calls and not available_functions:
|
||||
self._emit_call_completed_event(
|
||||
response=list(message.tool_calls),
|
||||
@@ -2046,7 +2024,6 @@ class OpenAICompletion(BaseLLM):
|
||||
)
|
||||
return list(message.tool_calls)
|
||||
|
||||
# If there are tool_calls and available_functions, execute the tools
|
||||
if message.tool_calls and available_functions:
|
||||
from openai.types.chat.chat_completion_message_function_tool_call import (
|
||||
ChatCompletionMessageFunctionToolCall,
|
||||
@@ -2322,12 +2299,10 @@ class OpenAICompletion(BaseLLM):
|
||||
"o4-mini": 200000,
|
||||
}
|
||||
|
||||
# Find the best match for the model name
|
||||
for model_prefix, size in context_windows.items():
|
||||
if self.model.startswith(model_prefix):
|
||||
return int(size * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_openai_token_usage(
|
||||
@@ -2358,7 +2333,6 @@ class OpenAICompletion(BaseLLM):
|
||||
"""Format messages for OpenAI API."""
|
||||
base_formatted = super()._format_messages(messages)
|
||||
|
||||
# Apply OpenAI-specific formatting
|
||||
formatted_messages: list[LLMMessage] = []
|
||||
|
||||
for message in base_formatted:
|
||||
|
||||
@@ -60,7 +60,6 @@ def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
||||
if not isinstance(tool, dict):
|
||||
raise ValueError("Tool must be a dictionary")
|
||||
|
||||
# Handle nested function schema format (OpenAI/standard)
|
||||
if "function" in tool:
|
||||
function_info = tool["function"]
|
||||
if not isinstance(function_info, dict):
|
||||
@@ -70,12 +69,11 @@ def extract_tool_info(tool: dict[str, Any]) -> tuple[str, str, dict[str, Any]]:
|
||||
description = function_info.get("description", "")
|
||||
parameters = function_info.get("parameters", {})
|
||||
else:
|
||||
# Direct format
|
||||
name = tool.get("name", "")
|
||||
description = tool.get("description", "")
|
||||
parameters = tool.get("parameters", {})
|
||||
|
||||
# Also check for args_schema (Pydantic format)
|
||||
# Fall back to args_schema for Pydantic-defined tools
|
||||
if not parameters and "args_schema" in tool:
|
||||
if hasattr(tool["args_schema"], "model_json_schema"):
|
||||
schema_output = generate_model_description(tool["args_schema"])
|
||||
|
||||
@@ -40,7 +40,6 @@ class _MCPToolResult(NamedTuple):
|
||||
is_error: bool
|
||||
|
||||
|
||||
# MCP Connection timeout constants (in seconds)
|
||||
MCP_CONNECTION_TIMEOUT = 30 # Increased for slow servers
|
||||
MCP_TOOL_EXECUTION_TIMEOUT = 30
|
||||
MCP_DISCOVERY_TIMEOUT = 30 # Increased for slow servers
|
||||
@@ -48,7 +47,6 @@ MCP_MAX_RETRIES = 3
|
||||
|
||||
_T = TypeVar("_T")
|
||||
|
||||
# Simple in-memory cache for MCP tool schemas (duration: 5 minutes)
|
||||
_mcp_schema_cache: dict[str, tuple[list[dict[str, Any]], float]] = {}
|
||||
_cache_ttl = 300 # 5 minutes
|
||||
|
||||
@@ -96,7 +94,6 @@ class MCPClient:
|
||||
self.discovery_timeout = discovery_timeout
|
||||
self.max_retries = max_retries
|
||||
self.cache_tools_list = cache_tools_list
|
||||
# self._logger = logger or logging.getLogger(__name__)
|
||||
self._session: Any = None
|
||||
self._initialized = False
|
||||
self._exit_stack = AsyncExitStack()
|
||||
@@ -152,11 +149,9 @@ class MCPClient:
|
||||
if self.connected:
|
||||
return self
|
||||
|
||||
# Get server info for events
|
||||
server_name, server_url, transport_type = self._get_server_info()
|
||||
is_reconnect = self._was_connected
|
||||
|
||||
# Emit connection started event
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -177,16 +172,14 @@ class MCPClient:
|
||||
# Always enter transport context via exit stack (it handles already-connected state)
|
||||
await self._exit_stack.enter_async_context(self.transport)
|
||||
|
||||
# Create ClientSession with transport streams
|
||||
self._session = ClientSession(
|
||||
self.transport.read_stream,
|
||||
self.transport.write_stream,
|
||||
)
|
||||
|
||||
# Enter the session's async context manager via exit stack
|
||||
await self._exit_stack.enter_async_context(self._session)
|
||||
|
||||
# Initialize the session (required by MCP protocol)
|
||||
# MCP protocol requires session.initialize() before any other request
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self._session.initialize(),
|
||||
@@ -391,23 +384,19 @@ class MCPClient:
|
||||
if not self.connected:
|
||||
await self.connect()
|
||||
|
||||
# Check cache if enabled
|
||||
use_cache = use_cache if use_cache is not None else self.cache_tools_list
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key("tools")
|
||||
if cache_key in _mcp_schema_cache:
|
||||
cached_data, cache_time = _mcp_schema_cache[cache_key]
|
||||
if time.time() - cache_time < _cache_ttl:
|
||||
# Logger removed - return cached data
|
||||
return cached_data
|
||||
|
||||
# List tools with timeout and retries
|
||||
tools = await self._retry_operation(
|
||||
self._list_tools_impl,
|
||||
timeout=self.discovery_timeout,
|
||||
)
|
||||
|
||||
# Cache results if enabled
|
||||
if use_cache:
|
||||
cache_key = self._get_cache_key("tools")
|
||||
_mcp_schema_cache[cache_key] = (tools, time.time())
|
||||
@@ -449,10 +438,8 @@ class MCPClient:
|
||||
arguments = arguments or {}
|
||||
cleaned_arguments = self._clean_tool_arguments(arguments)
|
||||
|
||||
# Get server info for events
|
||||
server_name, server_url, transport_type = self._get_server_info()
|
||||
|
||||
# Emit tool execution started event
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -542,34 +529,28 @@ class MCPClient:
|
||||
cleaned: dict[str, Any] = {}
|
||||
|
||||
for key, value in arguments.items():
|
||||
# Skip None values
|
||||
if value is None:
|
||||
continue
|
||||
|
||||
# Fix sources array format: convert ["web"] to [{"type": "web"}]
|
||||
# Normalize sources from ["web"] to [{"type": "web"}]
|
||||
if key == "sources" and isinstance(value, list):
|
||||
fixed_sources = []
|
||||
for item in value:
|
||||
if isinstance(item, str):
|
||||
# Convert string to object format
|
||||
fixed_sources.append({"type": item})
|
||||
elif isinstance(item, dict):
|
||||
# Already in correct format
|
||||
fixed_sources.append(item)
|
||||
else:
|
||||
# Keep as is if unknown format
|
||||
fixed_sources.append(item)
|
||||
if fixed_sources:
|
||||
cleaned[key] = fixed_sources
|
||||
continue
|
||||
|
||||
# Recursively clean nested dictionaries
|
||||
if isinstance(value, dict):
|
||||
nested_cleaned = self._clean_tool_arguments(value)
|
||||
if nested_cleaned: # Only add if not empty
|
||||
cleaned[key] = nested_cleaned
|
||||
elif isinstance(value, list):
|
||||
# Clean list items
|
||||
cleaned_list = []
|
||||
for item in value:
|
||||
if isinstance(item, dict):
|
||||
@@ -581,7 +562,6 @@ class MCPClient:
|
||||
if cleaned_list:
|
||||
cleaned[key] = cleaned_list
|
||||
else:
|
||||
# Keep primitive values
|
||||
cleaned[key] = value
|
||||
|
||||
return cleaned
|
||||
@@ -597,7 +577,6 @@ class MCPClient:
|
||||
|
||||
is_error = getattr(result, "isError", False) or False
|
||||
|
||||
# Extract result content
|
||||
if hasattr(result, "content") and result.content:
|
||||
if isinstance(result.content, list) and len(result.content) > 0:
|
||||
content_item = result.content[0]
|
||||
|
||||
@@ -120,5 +120,4 @@ class MCPServerSSE(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Type alias for all MCP server configurations
|
||||
MCPServerConfig = MCPServerStdio | MCPServerHTTP | MCPServerSSE
|
||||
|
||||
@@ -29,7 +29,6 @@ class ToolFilterContext(BaseModel):
|
||||
)
|
||||
|
||||
|
||||
# Type alias for tool filter functions
|
||||
ToolFilter = (
|
||||
Callable[[ToolFilterContext, dict[str, Any]], bool]
|
||||
| Callable[[dict[str, Any]], bool]
|
||||
@@ -79,15 +78,13 @@ class StaticToolFilter:
|
||||
"""
|
||||
tool_name = tool.get("name", "")
|
||||
|
||||
# Blocked tools take precedence
|
||||
# Blocked tools take precedence over allowed tools
|
||||
if self.blocked_tool_names and tool_name in self.blocked_tool_names:
|
||||
return False
|
||||
|
||||
# If allow list exists, tool must be in it
|
||||
if self.allowed_tool_names:
|
||||
return tool_name in self.allowed_tool_names
|
||||
|
||||
# No restrictions - allow all
|
||||
return True
|
||||
|
||||
|
||||
|
||||
@@ -33,10 +33,6 @@ from crewai.memory.utils import join_scope_paths
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# State models
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class ItemState(BaseModel):
|
||||
"""Per-item tracking within a batch."""
|
||||
@@ -51,18 +47,14 @@ class ItemState(BaseModel):
|
||||
private: bool = False
|
||||
# Structural root scope prefix for hierarchical scoping
|
||||
root_scope: str | None = None
|
||||
# Resolved values
|
||||
resolved_scope: str = "/"
|
||||
resolved_categories: list[str] = Field(default_factory=list)
|
||||
resolved_metadata: dict[str, Any] = Field(default_factory=dict)
|
||||
resolved_importance: float = 0.5
|
||||
resolved_source: str | None = None
|
||||
resolved_private: bool = False
|
||||
# Embedding
|
||||
embedding: list[float] = Field(default_factory=list)
|
||||
# Intra-batch dedup
|
||||
dropped: bool = False
|
||||
# Consolidation
|
||||
similar_records: list[MemoryRecord] = Field(default_factory=list)
|
||||
top_similarity: float = 0.0
|
||||
plan: ConsolidationPlan | None = None
|
||||
@@ -74,18 +66,12 @@ class EncodingState(BaseModel):
|
||||
|
||||
id: str = Field(default_factory=lambda: str(uuid4()))
|
||||
items: list[ItemState] = Field(default_factory=list)
|
||||
# Aggregate stats
|
||||
records_inserted: int = 0
|
||||
records_updated: int = 0
|
||||
records_deleted: int = 0
|
||||
items_dropped_dedup: int = 0
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Flow
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
|
||||
class EncodingFlow(Flow[EncodingState]):
|
||||
"""Batch-native encoding pipeline for memory.remember() / remember_many().
|
||||
|
||||
@@ -121,10 +107,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
self._embedder = embedder
|
||||
self._config = config or MemoryConfig()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 1: Batch embed (ONE embedder call)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@start()
|
||||
def batch_embed(self) -> None:
|
||||
"""Embed all items in a single embedder call."""
|
||||
@@ -134,10 +116,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
for item, emb in zip(items, embeddings, strict=False):
|
||||
item.embedding = emb
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 2: Intra-batch dedup (cosine similarity matrix)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@listen(batch_embed)
|
||||
def intra_batch_dedup(self) -> None:
|
||||
"""Drop near-exact duplicates within the batch."""
|
||||
@@ -171,10 +149,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
return 0.0
|
||||
return dot / (norm_a * norm_b)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 3: Parallel find similar (concurrent storage searches)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@listen(intra_batch_dedup)
|
||||
def parallel_find_similar(self) -> None:
|
||||
"""Search storage for similar records, concurrently for all active items."""
|
||||
@@ -244,10 +218,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
item.similar_records = [r for r, _ in raw]
|
||||
item.top_similarity = float(raw[0][1]) if raw else 0.0
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 4: Parallel analyze (N concurrent LLM calls)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@listen(parallel_find_similar)
|
||||
def parallel_analyze(self) -> None:
|
||||
"""Field resolution + consolidation via parallel individual LLM calls.
|
||||
@@ -273,7 +243,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
existing_categories: list[str] = []
|
||||
if any_needs_fields:
|
||||
# Constrain scope/category suggestions to root_scope boundary
|
||||
# Check if any active item has root_scope
|
||||
active_root = next(
|
||||
(it.root_scope for it in items if not it.dropped and it.root_scope),
|
||||
None,
|
||||
@@ -284,7 +253,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
self._storage.list_categories(scope_prefix=active_root).keys()
|
||||
)
|
||||
|
||||
# Classify items and submit LLM calls
|
||||
save_futures: dict[int, Future[MemoryAnalysis]] = {}
|
||||
consol_futures: dict[int, Future[ConsolidationPlan]] = {}
|
||||
|
||||
@@ -302,11 +270,9 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
has_similar = item.top_similarity >= threshold
|
||||
|
||||
if fields_provided and not has_similar:
|
||||
# Group A: fast path
|
||||
self._apply_defaults(item)
|
||||
item.plan = ConsolidationPlan(actions=[], insert_new=True)
|
||||
elif fields_provided and has_similar:
|
||||
# Group B: consolidation only
|
||||
self._apply_defaults(item)
|
||||
consol_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
@@ -316,7 +282,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
self._llm,
|
||||
)
|
||||
elif not fields_provided and not has_similar:
|
||||
# Group C: field resolution only
|
||||
save_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_save,
|
||||
@@ -326,7 +291,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
self._llm,
|
||||
)
|
||||
else:
|
||||
# Group D: both in parallel
|
||||
save_futures[i] = pool.submit(
|
||||
contextvars.copy_context().run,
|
||||
analyze_for_save,
|
||||
@@ -343,13 +307,10 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
self._llm,
|
||||
)
|
||||
|
||||
# Collect field-resolution results
|
||||
for i, future in save_futures.items():
|
||||
analysis = future.result()
|
||||
item = items[i]
|
||||
# Determine inner scope from explicit scope or LLM-inferred
|
||||
inner_scope = item.scope or analysis.suggested_scope or "/"
|
||||
# Join root_scope with inner scope if root_scope is set
|
||||
if item.root_scope:
|
||||
item.resolved_scope = join_scope_paths(item.root_scope, inner_scope)
|
||||
else:
|
||||
@@ -378,7 +339,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
if i not in consol_futures:
|
||||
item.plan = ConsolidationPlan(actions=[], insert_new=True)
|
||||
|
||||
# Collect consolidation results
|
||||
for i, consol_future in consol_futures.items():
|
||||
items[i].plan = consol_future.result()
|
||||
finally:
|
||||
@@ -391,7 +351,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
final resolved_scope.
|
||||
"""
|
||||
inner_scope = item.scope or "/"
|
||||
# Join root_scope with inner scope if root_scope is set
|
||||
if item.root_scope:
|
||||
item.resolved_scope = join_scope_paths(item.root_scope, inner_scope)
|
||||
else:
|
||||
@@ -407,10 +366,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
item.resolved_source = item.source
|
||||
item.resolved_private = item.private
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Step 5: Execute plans (batch re-embed + bulk insert)
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@listen(parallel_analyze)
|
||||
def execute_plans(self) -> None:
|
||||
"""Apply all consolidation plans with batch re-embedding and bulk insert.
|
||||
@@ -423,7 +378,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
items = list(self.state.items)
|
||||
now = datetime.utcnow()
|
||||
|
||||
# --- Deduplicate actions across all items ---
|
||||
# Multiple items may reference the same existing record (because their
|
||||
# similar_records overlap). Collect one action per record_id, first wins.
|
||||
# Also build a map from record_id to the original MemoryRecord for updates.
|
||||
@@ -455,7 +409,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
):
|
||||
dedup_updates[rid] = (i, action.new_content)
|
||||
|
||||
# --- Batch re-embed all update contents in ONE call ---
|
||||
update_list = list(
|
||||
dedup_updates.items()
|
||||
) # [(record_id, (item_idx, new_content)), ...]
|
||||
@@ -468,7 +421,6 @@ class EncodingFlow(Flow[EncodingState]):
|
||||
for (rid, _), emb in zip(update_list, update_embeddings, strict=False):
|
||||
update_emb_map[rid] = emb
|
||||
|
||||
# --- Apply all storage mutations under one lock ---
|
||||
# Hold the write lock for the entire delete + update + insert sequence
|
||||
# so no other pipeline can interleave and cause version conflicts.
|
||||
# The lock is reentrant (RLock), so the individual storage methods
|
||||
|
||||
@@ -80,10 +80,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
self._embedder = embedder
|
||||
self._config = config or MemoryConfig()
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Helpers
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _merged_categories(self) -> list[str] | None:
|
||||
"""Return caller-supplied categories, or None if empty."""
|
||||
return self.state.categories or None
|
||||
@@ -106,10 +102,8 @@ class RecallFlow(Flow[RecallState]):
|
||||
limit=self.state.limit * _RECALL_OVERSAMPLE_FACTOR,
|
||||
min_score=0.0,
|
||||
)
|
||||
# Post-filter by time cutoff
|
||||
if self.state.time_cutoff and raw:
|
||||
raw = [(r, s) for r, s in raw if r.created_at >= self.state.time_cutoff]
|
||||
# Privacy filter
|
||||
if not self.state.include_private and raw:
|
||||
raw = [
|
||||
(r, s)
|
||||
@@ -118,7 +112,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
]
|
||||
return scope, raw
|
||||
|
||||
# Build (embedding, scope) task list
|
||||
tasks: list[tuple[list[float], str]] = [
|
||||
(embedding, scope)
|
||||
for _query_text, embedding in self.state.query_embeddings
|
||||
@@ -182,10 +175,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
self.state.confidence = max((f["top_score"] for f in findings), default=0.0)
|
||||
return findings
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Flow steps
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
@start()
|
||||
def analyze_query_step(self) -> QueryAnalysis:
|
||||
"""Analyze the query, embed distilled sub-queries, extract filters.
|
||||
@@ -204,7 +193,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
skip_llm = query_len < self._config.query_analysis_threshold
|
||||
|
||||
if skip_llm:
|
||||
# Short query: skip LLM, embed raw query directly
|
||||
analysis = QueryAnalysis(
|
||||
keywords=[],
|
||||
suggested_scopes=[],
|
||||
@@ -213,7 +201,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
)
|
||||
self.state.query_analysis = analysis
|
||||
else:
|
||||
# Long query: use LLM to distill sub-queries and extract filters
|
||||
available = self._storage.list_scopes(self.state.scope or "/")
|
||||
if not available:
|
||||
available = ["/"]
|
||||
@@ -230,7 +217,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
)
|
||||
self.state.query_analysis = analysis
|
||||
|
||||
# Parse time_filter into a datetime cutoff
|
||||
if analysis.time_filter:
|
||||
try:
|
||||
self.state.time_cutoff = datetime.fromisoformat(
|
||||
@@ -239,7 +225,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
except ValueError:
|
||||
pass
|
||||
|
||||
# Batch-embed all sub-queries in ONE call
|
||||
queries = (
|
||||
analysis.recall_queries if analysis.recall_queries else [self.state.query]
|
||||
)
|
||||
@@ -249,7 +234,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
(q, emb) for q, emb in zip(queries, embeddings, strict=False) if emb
|
||||
]
|
||||
if not pairs:
|
||||
# Fallback: embed the raw query if distilled queries all failed
|
||||
fallback_emb = embed_texts(self._embedder, [self.state.query])
|
||||
if fallback_emb and fallback_emb[0]:
|
||||
pairs = [(self.state.query, fallback_emb[0])]
|
||||
@@ -386,7 +370,6 @@ class RecallFlow(Flow[RecallState]):
|
||||
matches.sort(key=lambda m: m.score, reverse=True)
|
||||
self.state.final_results = matches[: self.state.limit]
|
||||
|
||||
# Attach evidence gaps to the first result so callers can inspect them
|
||||
if self.state.evidence_gaps and self.state.final_results:
|
||||
self.state.final_results[0].evidence_gaps = list(self.state.evidence_gaps)
|
||||
|
||||
|
||||
@@ -23,7 +23,6 @@ class KickoffTaskOutputsSQLiteStorage:
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "latest_kickoff_task_outputs.db")
|
||||
self.db_path = db_path
|
||||
self._lock_name = f"sqlite:{os.path.realpath(self.db_path)}"
|
||||
|
||||
@@ -197,10 +197,6 @@ class LanceDBStorage:
|
||||
"Scope index creation skipped (may already exist)", exc_info=True
|
||||
)
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Automatic background compaction
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _compact_if_needed(self) -> None:
|
||||
"""Spawn a background compaction on startup.
|
||||
|
||||
|
||||
@@ -141,7 +141,6 @@ class MemoryConfig(BaseModel):
|
||||
compute_composite_score.
|
||||
"""
|
||||
|
||||
# -- Composite score weights --
|
||||
# The recall composite score is:
|
||||
# semantic_weight * similarity + recency_weight * decay + importance_weight * importance
|
||||
# These should sum to ~1.0 for intuitive 0-1 scoring.
|
||||
@@ -183,8 +182,6 @@ class MemoryConfig(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
# -- Consolidation (on save) --
|
||||
|
||||
consolidation_threshold: float = Field(
|
||||
default=0.85,
|
||||
ge=0.0,
|
||||
@@ -215,8 +212,6 @@ class MemoryConfig(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
# -- Save defaults --
|
||||
|
||||
default_importance: float = Field(
|
||||
default=0.5,
|
||||
ge=0.0,
|
||||
@@ -228,7 +223,6 @@ class MemoryConfig(BaseModel):
|
||||
),
|
||||
)
|
||||
|
||||
# -- Recall depth control --
|
||||
# The RecallFlow router uses these thresholds to decide between returning
|
||||
# results immediately ("synthesize") and doing an extra LLM-driven
|
||||
# exploration round ("explore_deeper").
|
||||
@@ -330,7 +324,6 @@ def embed_texts(embedder: Any, texts: list[str]) -> list[list[float]]:
|
||||
"""
|
||||
if not texts:
|
||||
return []
|
||||
# Filter out empty texts, remembering their positions
|
||||
valid: list[tuple[int, str]] = [
|
||||
(i, t) for i, t in enumerate(texts) if t and t.strip()
|
||||
]
|
||||
|
||||
@@ -166,7 +166,6 @@ class Memory(BaseModel):
|
||||
object.__setattr__(
|
||||
new, "__pydantic_extra__", _copy.deepcopy(self.__pydantic_extra__, memo)
|
||||
)
|
||||
# Private attrs: create fresh pool/lock instead of deepcopying
|
||||
private = {}
|
||||
for k, v in (self.__pydantic_private__ or {}).items():
|
||||
if isinstance(v, (ThreadPoolExecutor, threading.Lock)):
|
||||
@@ -264,10 +263,6 @@ class Memory(BaseModel):
|
||||
) from e
|
||||
return self._embedder_instance
|
||||
|
||||
# ------------------------------------------------------------------
|
||||
# Background write queue
|
||||
# ------------------------------------------------------------------
|
||||
|
||||
def _submit_save(self, fn: Any, *args: Any, **kwargs: Any) -> Future[Any]:
|
||||
"""Submit a save operation to the background thread pool.
|
||||
|
||||
@@ -449,7 +444,6 @@ class Memory(BaseModel):
|
||||
start = time.perf_counter()
|
||||
|
||||
# Submit through the save pool for proper serialization,
|
||||
# then immediately wait for the result.
|
||||
future = self._submit_save(
|
||||
self._encode_batch,
|
||||
[content],
|
||||
@@ -676,12 +670,10 @@ class Memory(BaseModel):
|
||||
# so that the search sees all persisted records.
|
||||
self.drain_writes()
|
||||
|
||||
# Apply root_scope as default scope_prefix for read isolation
|
||||
effective_scope = scope
|
||||
if effective_scope is None and self.root_scope:
|
||||
effective_scope = self.root_scope
|
||||
elif effective_scope is not None and self.root_scope:
|
||||
# Nest provided scope under root
|
||||
effective_scope = join_scope_paths(self.root_scope, effective_scope)
|
||||
|
||||
_source = "unified_memory"
|
||||
@@ -709,7 +701,6 @@ class Memory(BaseModel):
|
||||
limit=limit,
|
||||
min_score=0.0,
|
||||
)
|
||||
# Privacy filter
|
||||
if not include_private:
|
||||
raw = [
|
||||
(r, s)
|
||||
@@ -748,7 +739,6 @@ class Memory(BaseModel):
|
||||
)
|
||||
results = flow.state.final_results
|
||||
|
||||
# Update last_accessed for recalled records
|
||||
if results:
|
||||
try:
|
||||
touch = getattr(self._storage, "touch_records", None)
|
||||
|
||||
@@ -30,11 +30,8 @@ def sanitize_scope_name(name: str) -> str:
|
||||
if not name:
|
||||
return "unknown"
|
||||
name = name.lower().strip()
|
||||
# Replace any character that's not alphanumeric, underscore, or hyphen with hyphen
|
||||
name = re.sub(r"[^a-z0-9_-]", "-", name)
|
||||
# Collapse multiple hyphens into one
|
||||
name = re.sub(r"-+", "-", name)
|
||||
# Strip leading/trailing hyphens
|
||||
name = name.strip("-")
|
||||
return name or "unknown"
|
||||
|
||||
@@ -59,12 +56,9 @@ def normalize_scope_path(path: str) -> str:
|
||||
"""
|
||||
if not path or path == "/":
|
||||
return "/"
|
||||
# Collapse multiple slashes
|
||||
path = re.sub(r"/+", "/", path)
|
||||
# Ensure leading slash
|
||||
if not path.startswith("/"):
|
||||
path = "/" + path
|
||||
# Remove trailing slash (unless it's just '/')
|
||||
if len(path) > 1:
|
||||
path = path.rstrip("/")
|
||||
return path
|
||||
@@ -94,7 +88,6 @@ def join_scope_paths(root: str | None, inner: str | None) -> str:
|
||||
>>> join_scope_paths(None, None)
|
||||
'/'
|
||||
"""
|
||||
# Normalize both parts
|
||||
root = root.rstrip("/") if root else ""
|
||||
inner = inner.strip("/") if inner else ""
|
||||
|
||||
|
||||
@@ -213,11 +213,9 @@ def crew(
|
||||
instantiated_agents: list[Agent] = []
|
||||
agent_roles: set[str] = set()
|
||||
|
||||
# Use the preserved task and agent information
|
||||
tasks = self.__crew_metadata__["original_tasks"].items()
|
||||
agents = self.__crew_metadata__["original_agents"].items()
|
||||
|
||||
# Instantiate tasks in order
|
||||
for _, task_method in tasks:
|
||||
task_instance = _call_method(task_method, self)
|
||||
instantiated_tasks.append(task_instance)
|
||||
@@ -226,7 +224,6 @@ def crew(
|
||||
instantiated_agents.append(agent_instance)
|
||||
agent_roles.add(agent_instance.role)
|
||||
|
||||
# Instantiate agents not included by tasks
|
||||
for _, agent_method in agents:
|
||||
agent_instance = _call_method(agent_method, self)
|
||||
if agent_instance.role not in agent_roles:
|
||||
|
||||
@@ -46,7 +46,6 @@ class AgentConfig(TypedDict, total=False):
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core agent attributes (from BaseAgent)
|
||||
role: str
|
||||
goal: str
|
||||
backstory: str
|
||||
@@ -58,35 +57,28 @@ class AgentConfig(TypedDict, total=False):
|
||||
max_tokens: int
|
||||
callbacks: list[str]
|
||||
|
||||
# LLM configuration
|
||||
llm: str
|
||||
function_calling_llm: str
|
||||
use_system_prompt: bool
|
||||
|
||||
# Template configuration
|
||||
system_template: str
|
||||
prompt_template: str
|
||||
response_template: str
|
||||
|
||||
# Tools and handlers (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
step_callback: str
|
||||
cache_handler: str | CacheHandler
|
||||
|
||||
# Code execution
|
||||
allow_code_execution: bool
|
||||
code_execution_mode: Literal["safe", "unsafe"]
|
||||
|
||||
# Context and performance
|
||||
respect_context_window: bool
|
||||
max_retry_limit: int
|
||||
|
||||
# Multimodal and reasoning
|
||||
multimodal: bool
|
||||
reasoning: bool
|
||||
max_reasoning_attempts: int
|
||||
|
||||
# Knowledge configuration
|
||||
knowledge_sources: list[str] | list[Any]
|
||||
knowledge_storage: str | Any
|
||||
knowledge_config: dict[str, Any]
|
||||
@@ -95,7 +87,6 @@ class AgentConfig(TypedDict, total=False):
|
||||
crew_knowledge_context: str
|
||||
knowledge_search_query: str
|
||||
|
||||
# Misc configuration
|
||||
inject_date: bool
|
||||
date_format: str
|
||||
from_repository: str
|
||||
@@ -110,36 +101,29 @@ class TaskConfig(TypedDict, total=False):
|
||||
Fields can be either string references (from YAML) or actual instances (after processing).
|
||||
"""
|
||||
|
||||
# Core task attributes
|
||||
name: str
|
||||
description: str
|
||||
expected_output: str
|
||||
|
||||
# Agent and context
|
||||
agent: str
|
||||
context: list[str]
|
||||
|
||||
# Tools and callbacks (can be string references or instances)
|
||||
tools: list[str] | list[BaseTool]
|
||||
callback: str
|
||||
callbacks: list[str]
|
||||
|
||||
# Output configuration
|
||||
output_json: str
|
||||
output_pydantic: str
|
||||
output_file: str
|
||||
create_directory: bool
|
||||
|
||||
# Execution configuration
|
||||
async_execution: bool
|
||||
human_input: bool
|
||||
markdown: bool
|
||||
|
||||
# Guardrail configuration
|
||||
guardrail: Callable[[TaskOutput], tuple[bool, Any]] | str
|
||||
guardrail_max_retries: int
|
||||
|
||||
# Misc configuration
|
||||
allow_crewai_trigger_context: bool
|
||||
|
||||
|
||||
@@ -811,7 +795,6 @@ class CrewBase(metaclass=_CrewBaseType):
|
||||
Reference: https://stackoverflow.com/questions/11091609/setting-a-class-metaclass-using-a-decorator
|
||||
"""
|
||||
|
||||
# e
|
||||
if TYPE_CHECKING:
|
||||
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
|
||||
@@ -52,7 +52,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
)
|
||||
"""
|
||||
|
||||
# Models that use the legacy vertexai.language_models SDK
|
||||
LEGACY_MODELS: ClassVar[set[str]] = {
|
||||
"textembedding-gecko",
|
||||
"textembedding-gecko@001",
|
||||
@@ -64,7 +63,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
"textembedding-gecko-multilingual@latest",
|
||||
}
|
||||
|
||||
# Models that use the new google-genai SDK
|
||||
GENAI_MODELS: ClassVar[set[str]] = {
|
||||
"gemini-embedding-001",
|
||||
"text-embedding-005",
|
||||
@@ -84,7 +82,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
- task_type: Task type for embeddings (default: "RETRIEVAL_DOCUMENT", new SDK only)
|
||||
- output_dimensionality: Optional output embedding dimension (new SDK only)
|
||||
"""
|
||||
# Handle deprecated 'region' parameter (only if it has a value)
|
||||
region_value = kwargs.pop("region", None) # type: ignore[typeddict-item,unused-ignore]
|
||||
if region_value is not None:
|
||||
warnings.warn(
|
||||
@@ -161,7 +158,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
self._task_type = kwargs.get("task_type", "RETRIEVAL_DOCUMENT")
|
||||
self._output_dimensionality = kwargs.get("output_dimensionality")
|
||||
|
||||
# Initialize client based on authentication mode
|
||||
api_key = kwargs.get("api_key")
|
||||
project_id = kwargs.get("project_id")
|
||||
location: str = str(kwargs.get("location", "us-central1"))
|
||||
@@ -216,7 +212,6 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
|
||||
def _call_genai(self, input: list[str]) -> Embeddings:
|
||||
"""Generate embeddings using the new google-genai SDK."""
|
||||
# Build config for embed_content
|
||||
config_kwargs: dict[str, Any] = {
|
||||
"task_type": self._task_type,
|
||||
}
|
||||
@@ -225,14 +220,12 @@ class GoogleGenAIVertexEmbeddingFunction(EmbeddingFunction[Documents]):
|
||||
|
||||
config = self._EmbedContentConfig(**config_kwargs)
|
||||
|
||||
# Call the embedding API
|
||||
response = self._client.models.embed_content(
|
||||
model=self._model_name,
|
||||
contents=input, # type: ignore[arg-type]
|
||||
config=config,
|
||||
)
|
||||
|
||||
# Extract embeddings from response
|
||||
if response.embeddings is None:
|
||||
raise ValueError("No embeddings returned from the API")
|
||||
embeddings = [emb.values for emb in response.embeddings]
|
||||
|
||||
@@ -19,14 +19,11 @@ def _validate_metadata(v: Any) -> dict[str, Any]:
|
||||
if not isinstance(v, dict):
|
||||
raise ValueError("Metadata must be a dictionary")
|
||||
|
||||
# Validate that all keys are strings
|
||||
for key, value in v.items():
|
||||
if not isinstance(key, str):
|
||||
raise ValueError(f"Metadata keys must be strings, got {type(key)}")
|
||||
|
||||
# Validate nested dictionaries (prevent deeply nested structures)
|
||||
if isinstance(value, dict):
|
||||
# Check for nested dictionaries (limit depth to 1)
|
||||
for nested_key, nested_value in value.items():
|
||||
if not isinstance(nested_key, str):
|
||||
raise ValueError(
|
||||
@@ -35,7 +32,6 @@ def _validate_metadata(v: Any) -> dict[str, Any]:
|
||||
if isinstance(nested_value, dict):
|
||||
raise ValueError("Metadata can only be nested one level deep")
|
||||
|
||||
# Check for maximum metadata size (prevent DoS)
|
||||
if len(str(v)) > 10_000: # Limit metadata size to 10KB
|
||||
raise ValueError("Metadata size exceeds maximum allowed (10KB)")
|
||||
|
||||
@@ -107,7 +103,6 @@ class Fingerprint(BaseModel):
|
||||
"""
|
||||
fingerprint = cls(metadata=metadata or {})
|
||||
if seed:
|
||||
# For seed-based generation, we need to manually set the _uuid_str after creation
|
||||
fingerprint.__dict__["_uuid_str"] = cls._generate_uuid(seed)
|
||||
return fingerprint
|
||||
|
||||
@@ -152,7 +147,6 @@ class Fingerprint(BaseModel):
|
||||
|
||||
fingerprint = cls(metadata=data.get("metadata", {}))
|
||||
|
||||
# For consistency with existing stored fingerprints, we need to manually set these
|
||||
if "uuid_str" in data:
|
||||
fingerprint.__dict__["_uuid_str"] = data["uuid_str"]
|
||||
if "created_at" in data and isinstance(data["created_at"], str):
|
||||
|
||||
@@ -63,7 +63,6 @@ class SkillCacheManager:
|
||||
Path to the stored skill directory.
|
||||
"""
|
||||
skill_dir = self._skill_dir(org, name)
|
||||
# Wipe any previous version
|
||||
if skill_dir.exists():
|
||||
import shutil
|
||||
|
||||
@@ -72,7 +71,6 @@ class SkillCacheManager:
|
||||
|
||||
import io
|
||||
|
||||
# Try tar.gz first, fall back to zip
|
||||
try:
|
||||
with tarfile.open(fileobj=io.BytesIO(archive_bytes), mode="r:gz") as tf:
|
||||
try:
|
||||
|
||||
@@ -105,7 +105,6 @@ def resolve_registry_ref(
|
||||
|
||||
org, name = parse_registry_ref(ref)
|
||||
|
||||
# 1. Project-local: ./skills/{name}/
|
||||
local_path = Path.cwd() / "skills" / name
|
||||
if local_path.is_dir() and (local_path / "SKILL.md").exists():
|
||||
try:
|
||||
@@ -114,7 +113,6 @@ def resolve_registry_ref(
|
||||
except Exception:
|
||||
_logger.debug("Failed to load local skill at %s", local_path, exc_info=True)
|
||||
|
||||
# 2. Global cache
|
||||
cache = SkillCacheManager()
|
||||
cached_path = cache.get_cached_path(org, name)
|
||||
if cached_path is not None and (cached_path / "SKILL.md").exists():
|
||||
@@ -126,7 +124,6 @@ def resolve_registry_ref(
|
||||
"Failed to load cached skill at %s", cached_path, exc_info=True
|
||||
)
|
||||
|
||||
# 3. Download
|
||||
if _is_noninteractive():
|
||||
raise SkillNotCachedError(ref)
|
||||
|
||||
@@ -197,7 +194,6 @@ def download_skill(
|
||||
archive_bytes = dl_response.content
|
||||
else:
|
||||
encoded = data.get("file", "")
|
||||
# Strip data URI prefix if present
|
||||
if "," in encoded:
|
||||
encoded = encoded.split(",", 1)[1]
|
||||
archive_bytes = base64.b64decode(encoded)
|
||||
|
||||
@@ -12,12 +12,10 @@ from crewai.state.provider.sqlite_provider import SqliteProvider
|
||||
|
||||
|
||||
CheckpointEventType = Literal[
|
||||
# Task
|
||||
"task_started",
|
||||
"task_completed",
|
||||
"task_failed",
|
||||
"task_evaluation",
|
||||
# Crew
|
||||
"crew_kickoff_started",
|
||||
"crew_kickoff_completed",
|
||||
"crew_kickoff_failed",
|
||||
@@ -28,7 +26,6 @@ CheckpointEventType = Literal[
|
||||
"crew_test_completed",
|
||||
"crew_test_failed",
|
||||
"crew_test_result",
|
||||
# Agent
|
||||
"agent_execution_started",
|
||||
"agent_execution_completed",
|
||||
"agent_execution_error",
|
||||
@@ -38,7 +35,6 @@ CheckpointEventType = Literal[
|
||||
"agent_evaluation_started",
|
||||
"agent_evaluation_completed",
|
||||
"agent_evaluation_failed",
|
||||
# Flow
|
||||
"flow_created",
|
||||
"flow_started",
|
||||
"flow_finished",
|
||||
@@ -51,24 +47,20 @@ CheckpointEventType = Literal[
|
||||
"human_feedback_received",
|
||||
"flow_input_requested",
|
||||
"flow_input_received",
|
||||
# LLM
|
||||
"llm_call_started",
|
||||
"llm_call_completed",
|
||||
"llm_call_failed",
|
||||
"llm_stream_chunk",
|
||||
"llm_thinking_chunk",
|
||||
# LLM Guardrail
|
||||
"llm_guardrail_started",
|
||||
"llm_guardrail_completed",
|
||||
"llm_guardrail_failed",
|
||||
# Tool
|
||||
"tool_usage_started",
|
||||
"tool_usage_finished",
|
||||
"tool_usage_error",
|
||||
"tool_validate_input_error",
|
||||
"tool_selection_error",
|
||||
"tool_execution_error",
|
||||
# Memory
|
||||
"memory_save_started",
|
||||
"memory_save_completed",
|
||||
"memory_save_failed",
|
||||
@@ -78,18 +70,15 @@ CheckpointEventType = Literal[
|
||||
"memory_retrieval_started",
|
||||
"memory_retrieval_completed",
|
||||
"memory_retrieval_failed",
|
||||
# Knowledge
|
||||
"knowledge_search_query_started",
|
||||
"knowledge_search_query_completed",
|
||||
"knowledge_query_started",
|
||||
"knowledge_query_completed",
|
||||
"knowledge_query_failed",
|
||||
"knowledge_search_query_failed",
|
||||
# Reasoning
|
||||
"agent_reasoning_started",
|
||||
"agent_reasoning_completed",
|
||||
"agent_reasoning_failed",
|
||||
# MCP
|
||||
"mcp_connection_started",
|
||||
"mcp_connection_completed",
|
||||
"mcp_connection_failed",
|
||||
@@ -97,23 +86,19 @@ CheckpointEventType = Literal[
|
||||
"mcp_tool_execution_completed",
|
||||
"mcp_tool_execution_failed",
|
||||
"mcp_config_fetch_failed",
|
||||
# Observation
|
||||
"step_observation_started",
|
||||
"step_observation_completed",
|
||||
"step_observation_failed",
|
||||
"plan_refinement",
|
||||
"plan_replan_triggered",
|
||||
"goal_achieved_early",
|
||||
# Skill
|
||||
"skill_discovery_started",
|
||||
"skill_discovery_completed",
|
||||
"skill_loaded",
|
||||
"skill_activated",
|
||||
"skill_load_failed",
|
||||
# Logging
|
||||
"agent_logs_started",
|
||||
"agent_logs_execution",
|
||||
# A2A
|
||||
"a2a_delegation_started",
|
||||
"a2a_delegation_completed",
|
||||
"a2a_conversation_started",
|
||||
@@ -145,13 +130,11 @@ CheckpointEventType = Literal[
|
||||
"a2a_context_idle",
|
||||
"a2a_context_completed",
|
||||
"a2a_context_pruned",
|
||||
# System
|
||||
"SIGTERM",
|
||||
"SIGINT",
|
||||
"SIGHUP",
|
||||
"SIGTSTP",
|
||||
"SIGCONT",
|
||||
# Env
|
||||
"cc_env",
|
||||
"codex_env",
|
||||
"cursor_env",
|
||||
|
||||
@@ -112,7 +112,6 @@ def _migrate(data: dict[str, Any]) -> dict[str, Any]:
|
||||
current,
|
||||
)
|
||||
|
||||
# --- migrations in version order ---
|
||||
if stored < Version("1.14.6"):
|
||||
for entity in data.get("entities") or []:
|
||||
_backfill_discriminators(entity)
|
||||
|
||||
@@ -338,7 +338,6 @@ class Task(BaseModel):
|
||||
if len(positional_args) != 1:
|
||||
raise ValueError("Guardrail function must accept exactly one parameter")
|
||||
|
||||
# Check return annotation if present, but don't require it
|
||||
return_annotation = sig.return_annotation
|
||||
if return_annotation != inspect.Signature.empty:
|
||||
return_annotation_args = get_args(return_annotation)
|
||||
@@ -505,34 +504,28 @@ class Task(BaseModel):
|
||||
if value is None:
|
||||
return None
|
||||
|
||||
# Basic security checks
|
||||
if ".." in value:
|
||||
raise ValueError(
|
||||
"Path traversal attempts are not allowed in output_file paths"
|
||||
)
|
||||
|
||||
# Check for shell expansion first
|
||||
if value.startswith(("~", "$")):
|
||||
raise ValueError(
|
||||
"Shell expansion characters are not allowed in output_file paths"
|
||||
)
|
||||
|
||||
# Then check other shell special characters
|
||||
if any(char in value for char in ["|", ">", "<", "&", ";"]):
|
||||
raise ValueError(
|
||||
"Shell special characters are not allowed in output_file paths"
|
||||
)
|
||||
|
||||
# Don't strip leading slash if it's a template path with variables
|
||||
if "{" in value or "}" in value:
|
||||
# Validate template variable format
|
||||
template_vars = [part.split("}")[0] for part in value.split("{")[1:]]
|
||||
for var in template_vars:
|
||||
if not var.isidentifier():
|
||||
raise ValueError(f"Invalid template variable name: {var}")
|
||||
return value
|
||||
|
||||
# Strip leading slash for regular paths
|
||||
if value.startswith("/"):
|
||||
return value[1:]
|
||||
return value
|
||||
@@ -761,7 +754,7 @@ class Task(BaseModel):
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
raise e
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
reset_current_task_id(task_id_token)
|
||||
@@ -842,7 +835,6 @@ class Task(BaseModel):
|
||||
guardrail_index=idx,
|
||||
)
|
||||
|
||||
# backwards support
|
||||
if self._guardrail:
|
||||
task_output = self._invoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
@@ -887,7 +879,7 @@ class Task(BaseModel):
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self))
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
raise e
|
||||
finally:
|
||||
clear_task_files(self.id)
|
||||
reset_current_task_id(task_id_token)
|
||||
@@ -1280,7 +1272,6 @@ Follow these guidelines:
|
||||
)
|
||||
|
||||
if guardrail_result.success:
|
||||
# Guardrail passed
|
||||
if guardrail_result.result is None:
|
||||
raise Exception(
|
||||
"Task guardrail returned None as result. This is not allowed."
|
||||
@@ -1298,9 +1289,7 @@ Follow these guidelines:
|
||||
|
||||
return task_output
|
||||
|
||||
# Guardrail failed
|
||||
if attempt >= self.guardrail_max_retries:
|
||||
# Max retries reached
|
||||
guardrail_name = (
|
||||
f"guardrail {guardrail_index}"
|
||||
if guardrail_index is not None
|
||||
@@ -1328,7 +1317,6 @@ Follow these guidelines:
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
# Regenerate output from agent
|
||||
result = agent.execute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
|
||||
@@ -141,7 +141,7 @@ class Telemetry:
|
||||
e,
|
||||
(SystemExit, KeyboardInterrupt, GeneratorExit, asyncio.CancelledError),
|
||||
):
|
||||
raise # Re-raise the exception to not interfere with system signals
|
||||
raise
|
||||
self.ready = False
|
||||
|
||||
@classmethod
|
||||
@@ -285,14 +285,12 @@ class Telemetry:
|
||||
self._add_attribute(span, "crew_number_of_tasks", len(crew.tasks))
|
||||
self._add_attribute(span, "crew_number_of_agents", len(crew.agents))
|
||||
|
||||
# Add additional fingerprint metadata if available
|
||||
if hasattr(crew, "fingerprint") and crew.fingerprint:
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_fingerprint_created_at",
|
||||
crew.fingerprint.created_at.isoformat(),
|
||||
)
|
||||
# Add fingerprint metadata if it exists
|
||||
if hasattr(crew.fingerprint, "metadata") and crew.fingerprint.metadata:
|
||||
self._add_attribute(
|
||||
span,
|
||||
@@ -337,7 +335,6 @@ class Telemetry:
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in agent.tools or []
|
||||
],
|
||||
# Add agent fingerprint data if sharing crew details
|
||||
"fingerprint": (
|
||||
getattr(
|
||||
getattr(agent, "fingerprint", None),
|
||||
@@ -387,7 +384,6 @@ class Telemetry:
|
||||
sanitize_tool_name(tool.name)
|
||||
for tool in task.tools or []
|
||||
],
|
||||
# Add task fingerprint data if sharing crew details
|
||||
"fingerprint": (
|
||||
task.fingerprint.uuid_str
|
||||
if hasattr(task, "fingerprint") and task.fingerprint
|
||||
@@ -502,7 +498,6 @@ class Telemetry:
|
||||
"task_fingerprint_created_at",
|
||||
task.fingerprint.created_at.isoformat(),
|
||||
)
|
||||
# Add fingerprint metadata if it exists
|
||||
if hasattr(task.fingerprint, "metadata") and task.fingerprint.metadata:
|
||||
self._add_attribute(
|
||||
created_span,
|
||||
@@ -510,7 +505,6 @@ class Telemetry:
|
||||
json.dumps(task.fingerprint.metadata),
|
||||
)
|
||||
|
||||
# Add agent fingerprint if task has an assigned agent
|
||||
if hasattr(task, "agent") and task.agent:
|
||||
add_agent_fingerprint_to_span(
|
||||
created_span, task.agent, self._add_attribute
|
||||
@@ -533,7 +527,6 @@ class Telemetry:
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
|
||||
# Add agent fingerprint if task has an assigned agent
|
||||
if hasattr(task, "agent") and task.agent:
|
||||
add_agent_fingerprint_to_span(span, task.agent, self._add_attribute)
|
||||
|
||||
@@ -560,7 +553,6 @@ class Telemetry:
|
||||
"""
|
||||
|
||||
def _operation() -> None:
|
||||
# Ensure fingerprint data is present on completion span
|
||||
if hasattr(task, "fingerprint") and task.fingerprint:
|
||||
self._add_attribute(span, "task_fingerprint", task.fingerprint.uuid_str)
|
||||
|
||||
@@ -625,7 +617,6 @@ class Telemetry:
|
||||
if llm:
|
||||
self._add_attribute(span, "llm", llm.model)
|
||||
|
||||
# Add agent fingerprint data if available
|
||||
add_agent_fingerprint_to_span(span, agent, self._add_attribute)
|
||||
close_span(span)
|
||||
|
||||
@@ -656,7 +647,6 @@ class Telemetry:
|
||||
if tool_name:
|
||||
self._add_attribute(span, "tool_name", tool_name)
|
||||
|
||||
# Add agent fingerprint data if available
|
||||
add_agent_fingerprint_to_span(span, agent, self._add_attribute)
|
||||
close_span(span)
|
||||
|
||||
|
||||
@@ -27,13 +27,11 @@ def add_agent_fingerprint_to_span(
|
||||
add_attribute_fn: Function to add attributes to the span.
|
||||
"""
|
||||
if agent:
|
||||
# Try to get fingerprint directly
|
||||
if hasattr(agent, "fingerprint") and agent.fingerprint:
|
||||
add_attribute_fn(span, "agent_fingerprint", agent.fingerprint.uuid_str)
|
||||
if hasattr(agent, "role"):
|
||||
add_attribute_fn(span, "agent_role", agent.role)
|
||||
else:
|
||||
# Try to get fingerprint using getattr (for cases where it might not be directly accessible)
|
||||
agent_fingerprint = getattr(
|
||||
getattr(agent, "fingerprint", None), "uuid_str", None
|
||||
)
|
||||
|
||||
@@ -31,9 +31,7 @@ class BaseAgentTool(BaseTool):
|
||||
"""
|
||||
if not name:
|
||||
return ""
|
||||
# Normalize all whitespace (including newlines) to single spaces
|
||||
normalized = " ".join(name.split())
|
||||
# Remove quotes and convert to lowercase
|
||||
return normalized.replace('"', "").casefold()
|
||||
|
||||
@staticmethod
|
||||
@@ -70,7 +68,6 @@ class BaseAgentTool(BaseTool):
|
||||
# have difficulty producing valid JSON.
|
||||
# As a result, we end up with invalid JSON that is truncated like this:
|
||||
# {"task": "....", "coworker": "....
|
||||
# when it should look like this:
|
||||
# {"task": "....", "coworker": "...."}
|
||||
sanitized_name = self.sanitize_agent_name(agent_name)
|
||||
logger.debug(
|
||||
@@ -89,7 +86,6 @@ class BaseAgentTool(BaseTool):
|
||||
f"Found {len(agent)} matching agents for role '{sanitized_name}'"
|
||||
)
|
||||
except (AttributeError, ValueError) as e:
|
||||
# Handle specific exceptions that might occur during role name processing
|
||||
return I18N_DEFAULT.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[
|
||||
@@ -101,7 +97,6 @@ class BaseAgentTool(BaseTool):
|
||||
)
|
||||
|
||||
if not agent:
|
||||
# No matching agent found after sanitization
|
||||
return I18N_DEFAULT.errors("agent_tool_unexisting_coworker").format(
|
||||
coworkers="\n".join(
|
||||
[
|
||||
@@ -124,7 +119,6 @@ class BaseAgentTool(BaseTool):
|
||||
)
|
||||
return selected_agent.execute_task(task_with_assigned_agent, context)
|
||||
except Exception as e:
|
||||
# Handle task creation or execution errors
|
||||
return I18N_DEFAULT.errors("agent_tool_execution_error").format(
|
||||
agent_role=self.sanitize_agent_name(selected_agent.role), error=str(e)
|
||||
)
|
||||
|
||||
@@ -31,13 +31,10 @@ class MCPToolWrapper(BaseTool):
|
||||
tool_schema: Schema information for the tool
|
||||
server_name: Name of the MCP server for prefixing
|
||||
"""
|
||||
# Create tool name with server prefix to avoid conflicts
|
||||
prefixed_name = f"{server_name}_{tool_name}"
|
||||
|
||||
# Handle args_schema properly - BaseTool expects a BaseModel subclass
|
||||
args_schema = tool_schema.get("args_schema")
|
||||
|
||||
# Only pass args_schema if it's provided
|
||||
kwargs = {
|
||||
"name": prefixed_name,
|
||||
"description": tool_schema.get(
|
||||
@@ -50,7 +47,6 @@ class MCPToolWrapper(BaseTool):
|
||||
|
||||
super().__init__(**kwargs)
|
||||
|
||||
# Set instance attributes after super().__init__
|
||||
self._mcp_server_params = mcp_server_params
|
||||
self._original_tool_name = tool_name
|
||||
self._server_name = server_name
|
||||
@@ -99,20 +95,16 @@ class MCPToolWrapper(BaseTool):
|
||||
last_error = None
|
||||
|
||||
for attempt in range(MCP_MAX_RETRIES):
|
||||
# Execute single attempt outside try-except loop structure
|
||||
result, error, should_retry = await self._execute_single_attempt(
|
||||
operation_func, **kwargs
|
||||
)
|
||||
|
||||
# Success case - return immediately
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Non-retryable error - return immediately
|
||||
if not should_retry:
|
||||
return error
|
||||
|
||||
# Retryable error - continue with backoff
|
||||
last_error = error
|
||||
if attempt < MCP_MAX_RETRIES - 1:
|
||||
wait_time = 2**attempt # Exponential backoff
|
||||
@@ -147,7 +139,6 @@ class MCPToolWrapper(BaseTool):
|
||||
except Exception as e:
|
||||
error_str = str(e).lower()
|
||||
|
||||
# Classify errors as retryable or non-retryable
|
||||
if "authentication" in error_str or "unauthorized" in error_str:
|
||||
return None, f"Authentication failed for MCP server: {e!s}", False
|
||||
if "not found" in error_str:
|
||||
|
||||
@@ -134,14 +134,11 @@ class CrewStructuredTool(BaseModel):
|
||||
f"Function {name} must have a docstring if description not provided."
|
||||
)
|
||||
|
||||
# Clean up the description
|
||||
description = textwrap.dedent(description).strip()
|
||||
|
||||
if args_schema is not None:
|
||||
# Use provided schema
|
||||
schema = args_schema
|
||||
elif infer_schema:
|
||||
# Infer schema from function signature
|
||||
schema = cls._create_schema_from_function(name, func)
|
||||
else:
|
||||
raise ValueError(
|
||||
@@ -170,29 +167,21 @@ class CrewStructuredTool(BaseModel):
|
||||
Returns:
|
||||
A Pydantic model class
|
||||
"""
|
||||
# Get function signature
|
||||
sig = inspect.signature(func)
|
||||
|
||||
# Get type hints
|
||||
type_hints = get_type_hints(func)
|
||||
|
||||
# Create field definitions
|
||||
fields = {}
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self/cls for methods
|
||||
if param_name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
# Get type annotation
|
||||
annotation = type_hints.get(param_name, Any)
|
||||
|
||||
# Get default value
|
||||
default = ... if param.default == param.empty else param.default
|
||||
|
||||
# Add field
|
||||
fields[param_name] = (annotation, Field(default=default))
|
||||
|
||||
# Create model
|
||||
schema_name = f"{name.title()}Schema"
|
||||
return create_model(schema_name, **fields) # type: ignore[call-overload, no-any-return]
|
||||
|
||||
@@ -203,20 +192,16 @@ class CrewStructuredTool(BaseModel):
|
||||
sig = inspect.signature(self.func)
|
||||
schema_fields = self.args_schema.model_fields
|
||||
|
||||
# Check required parameters
|
||||
for param_name, param in sig.parameters.items():
|
||||
# Skip self/cls for methods
|
||||
if param_name in ("self", "cls"):
|
||||
continue
|
||||
|
||||
# Skip **kwargs parameters
|
||||
if param.kind in (
|
||||
inspect.Parameter.VAR_KEYWORD,
|
||||
inspect.Parameter.VAR_POSITIONAL,
|
||||
):
|
||||
continue
|
||||
|
||||
# Only validate required parameters without defaults
|
||||
if param.default == inspect.Parameter.empty:
|
||||
if param_name not in schema_fields:
|
||||
raise ValueError(
|
||||
@@ -276,7 +261,6 @@ class CrewStructuredTool(BaseModel):
|
||||
try:
|
||||
if inspect.iscoroutinefunction(self.func):
|
||||
return await self.func(**parsed_args, **kwargs)
|
||||
# Run sync functions in a thread pool
|
||||
import asyncio
|
||||
|
||||
return await asyncio.get_event_loop().run_in_executor(
|
||||
@@ -287,7 +271,6 @@ class CrewStructuredTool(BaseModel):
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
"""Legacy method for compatibility."""
|
||||
# Convert args/kwargs to our expected format
|
||||
if not self.args_schema:
|
||||
return self.func(*args, **kwargs)
|
||||
input_dict = dict(zip(self.args_schema.model_fields.keys(), args, strict=False))
|
||||
|
||||
@@ -107,7 +107,6 @@ class ToolUsage:
|
||||
self.function_calling_llm = function_calling_llm
|
||||
self.fingerprint_context = fingerprint_context or {}
|
||||
|
||||
# Set the maximum parsing attempts for bigger models
|
||||
if (
|
||||
self.function_calling_llm
|
||||
and self.function_calling_llm.model in OPENAI_BIGGER_MODELS
|
||||
@@ -301,7 +300,6 @@ class ToolUsage:
|
||||
result = usage_limit_error
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
result = self._format_result(result=result)
|
||||
# Don't return early - fall through to finally block
|
||||
elif result is None:
|
||||
try:
|
||||
if sanitize_tool_name(calling.tool_name) in [
|
||||
@@ -381,7 +379,6 @@ class ToolUsage:
|
||||
if available_tool and hasattr(
|
||||
available_tool, "_increment_usage_count"
|
||||
):
|
||||
# Use _increment_usage_count to sync count to original tool
|
||||
available_tool._increment_usage_count()
|
||||
if (
|
||||
hasattr(available_tool, "max_usage_count")
|
||||
@@ -534,7 +531,6 @@ class ToolUsage:
|
||||
result = usage_limit_error
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
result = self._format_result(result=result)
|
||||
# Don't return early - fall through to finally block
|
||||
elif result is None:
|
||||
try:
|
||||
if sanitize_tool_name(calling.tool_name) in [
|
||||
@@ -614,7 +610,6 @@ class ToolUsage:
|
||||
if available_tool and hasattr(
|
||||
available_tool, "_increment_usage_count"
|
||||
):
|
||||
# Use _increment_usage_count to sync count to original tool
|
||||
available_tool._increment_usage_count()
|
||||
if (
|
||||
hasattr(available_tool, "max_usage_count")
|
||||
@@ -868,32 +863,27 @@ class ToolUsage:
|
||||
"Tool input must be a valid dictionary in JSON or Python literal format"
|
||||
)
|
||||
|
||||
# Attempt 1: Parse as JSON
|
||||
try:
|
||||
arguments = json.loads(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (JSONDecodeError, TypeError):
|
||||
pass # Continue to the next parsing attempt
|
||||
pass
|
||||
|
||||
# Attempt 2: Parse as Python literal
|
||||
try:
|
||||
arguments = ast.literal_eval(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (ValueError, SyntaxError):
|
||||
repaired_input = repair_json(tool_input)
|
||||
# Continue to the next parsing attempt
|
||||
|
||||
# Attempt 3: Parse as JSON5
|
||||
try:
|
||||
arguments = json5.loads(tool_input)
|
||||
if isinstance(arguments, dict):
|
||||
return arguments
|
||||
except (JSONDecodeError, ValueError, TypeError):
|
||||
pass # Continue to the next parsing attempt
|
||||
pass
|
||||
|
||||
# Attempt 4: Repair JSON
|
||||
try:
|
||||
repaired_input = str(repair_json(tool_input, skip_json_loads=True))
|
||||
if self.agent and self.agent.verbose:
|
||||
@@ -910,7 +900,6 @@ class ToolUsage:
|
||||
"Tool input must be a valid dictionary in JSON or Python literal format"
|
||||
)
|
||||
self._emit_validate_input_error(error_message)
|
||||
# If all parsing attempts fail, raise an error
|
||||
raise Exception(error_message)
|
||||
|
||||
def _emit_validate_input_error(self, final_error: str) -> None:
|
||||
@@ -923,7 +912,6 @@ class ToolUsage:
|
||||
"agent": self.agent, # Adding agent for fingerprint extraction
|
||||
}
|
||||
|
||||
# Include fingerprint context if available
|
||||
if self.fingerprint_context:
|
||||
tool_selection_data.update(self.fingerprint_context)
|
||||
|
||||
@@ -1000,7 +988,6 @@ class ToolUsage:
|
||||
),
|
||||
}
|
||||
|
||||
# Include fingerprint context if available
|
||||
if self.fingerprint_context:
|
||||
event_data.update(self.fingerprint_context)
|
||||
|
||||
@@ -1017,7 +1004,6 @@ class ToolUsage:
|
||||
"""
|
||||
security_context: dict[str, Any] = {}
|
||||
|
||||
# Add agent fingerprint if available
|
||||
if self.agent and hasattr(self.agent, "security_config"):
|
||||
security_config = getattr(self.agent, "security_config", None)
|
||||
if security_config and hasattr(security_config, "fingerprint"):
|
||||
@@ -1028,7 +1014,6 @@ class ToolUsage:
|
||||
except AttributeError:
|
||||
pass
|
||||
|
||||
# Add task fingerprint if available
|
||||
if self.task and hasattr(self.task, "security_config"):
|
||||
security_config = getattr(self.task, "security_config", None)
|
||||
if security_config and hasattr(security_config, "fingerprint"):
|
||||
|
||||
@@ -112,14 +112,12 @@ def _resolve_dotted_path(path: str) -> Callable[..., Any]:
|
||||
ValueError: If no valid module can be imported from the path.
|
||||
"""
|
||||
parts = path.split(".")
|
||||
# Try importing progressively shorter prefixes as the module.
|
||||
for i in range(len(parts), 0, -1):
|
||||
module_path = ".".join(parts[:i])
|
||||
try:
|
||||
obj: Any = importlib.import_module(module_path)
|
||||
except (ImportError, TypeError, ValueError):
|
||||
continue
|
||||
# Walk the remaining attribute chain.
|
||||
try:
|
||||
for attr in parts[i:]:
|
||||
obj = getattr(obj, attr)
|
||||
|
||||
@@ -169,7 +169,6 @@ def convert_tools_to_openai_schema(
|
||||
tool_name_mapping: dict[str, BaseTool | CrewStructuredTool] = {}
|
||||
|
||||
for tool in tools:
|
||||
# Get the JSON schema for tool parameters
|
||||
parameters: dict[str, Any] = {}
|
||||
if hasattr(tool, "args_schema") and tool.args_schema is not None:
|
||||
try:
|
||||
@@ -177,13 +176,11 @@ def convert_tools_to_openai_schema(
|
||||
tool.args_schema, strip_null_types=False
|
||||
)
|
||||
parameters = schema_output.get("json_schema", {}).get("schema", {})
|
||||
# Remove title and description from schema root as they're redundant
|
||||
parameters.pop("title", None)
|
||||
parameters.pop("description", None)
|
||||
except Exception:
|
||||
parameters = {}
|
||||
|
||||
# Extract original description from formatted description
|
||||
# BaseTool formats description as "Tool Name: ...\nTool Arguments: ...\nTool Description: {original}"
|
||||
description = tool.description
|
||||
if "Tool Description:" in description:
|
||||
@@ -320,7 +317,6 @@ def handle_max_iterations_exceeded(
|
||||
|
||||
messages.append(format_message_for_llm(assistant_message, role="assistant"))
|
||||
|
||||
# Perform one more LLM call to get the final answer
|
||||
answer = llm.call(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
@@ -336,7 +332,6 @@ def handle_max_iterations_exceeded(
|
||||
|
||||
formatted = format_answer(answer=answer)
|
||||
|
||||
# If format_answer returned an AgentAction, convert it to AgentFinish
|
||||
if isinstance(formatted, AgentFinish):
|
||||
return formatted
|
||||
return AgentFinish(
|
||||
@@ -574,7 +569,6 @@ def process_llm_response(
|
||||
"""
|
||||
if not use_stop_words:
|
||||
try:
|
||||
# Preliminary parsing to check for errors.
|
||||
format_answer(answer)
|
||||
except OutputParserError as e:
|
||||
if FINAL_ANSWER_AND_PARSABLE_ACTION_ERROR_MESSAGE in e.error:
|
||||
@@ -778,7 +772,6 @@ def _format_messages_for_summary(messages: list[LLMMessage]) -> str:
|
||||
|
||||
content = msg.get("content")
|
||||
if content is None:
|
||||
# Check for tool_calls on assistant messages with no content
|
||||
tool_calls = msg.get("tool_calls")
|
||||
if tool_calls:
|
||||
tool_names = []
|
||||
@@ -794,7 +787,6 @@ def _format_messages_for_summary(messages: list[LLMMessage]) -> str:
|
||||
else:
|
||||
content = ""
|
||||
elif isinstance(content, list):
|
||||
# Multimodal content blocks — extract text parts
|
||||
text_parts = [
|
||||
block.get("text", "")
|
||||
for block in content
|
||||
@@ -849,8 +841,6 @@ def _split_messages_into_chunks(
|
||||
|
||||
msg_tokens = _estimate_token_count(msg_text)
|
||||
|
||||
# If adding this message would exceed the limit and we already have
|
||||
# messages in the current chunk, start a new chunk
|
||||
if current_chunk and (current_tokens + msg_tokens) > max_tokens:
|
||||
chunks.append(current_chunk)
|
||||
current_chunk = []
|
||||
@@ -939,29 +929,23 @@ def summarize_messages(
|
||||
callbacks: List of callbacks for LLM
|
||||
verbose: Whether to print progress.
|
||||
"""
|
||||
# 1. Extract & preserve file attachments from user messages
|
||||
preserved_files: dict[str, Any] = {}
|
||||
for msg in messages:
|
||||
if msg.get("role") == "user" and msg.get("files"):
|
||||
preserved_files.update(msg["files"])
|
||||
|
||||
# 2. Extract system messages — never summarize them
|
||||
system_messages = [m for m in messages if m.get("role") == "system"]
|
||||
non_system_messages = [m for m in messages if m.get("role") != "system"]
|
||||
|
||||
# If there are only system messages (or no non-system messages), nothing to summarize
|
||||
if not non_system_messages:
|
||||
return
|
||||
|
||||
# 3. Split non-system messages into chunks at message boundaries
|
||||
max_tokens = llm.get_context_window_size()
|
||||
chunks = _split_messages_into_chunks(non_system_messages, max_tokens)
|
||||
|
||||
# 4. Summarize each chunk with role-labeled formatting
|
||||
total_chunks = len(chunks)
|
||||
|
||||
if total_chunks <= 1:
|
||||
# Single chunk — no benefit from async overhead
|
||||
summarized_contents: list[SummaryContent] = []
|
||||
for idx, chunk in enumerate(chunks, 1):
|
||||
if verbose:
|
||||
@@ -984,7 +968,6 @@ def summarize_messages(
|
||||
extracted = _extract_summary_tags(str(summary))
|
||||
summarized_contents.append({"content": extracted})
|
||||
else:
|
||||
# Multiple chunks — summarize in parallel via asyncio
|
||||
if verbose:
|
||||
PRINTER.print(
|
||||
content=f"Summarizing {total_chunks} chunks in parallel...",
|
||||
@@ -1000,7 +983,6 @@ def summarize_messages(
|
||||
|
||||
merged_summary = "\n\n".join(content["content"] for content in summarized_contents)
|
||||
|
||||
# 6. Reconstruct messages: [system messages...] + [summary user message]
|
||||
messages.clear()
|
||||
messages.extend(system_messages)
|
||||
|
||||
@@ -1034,7 +1016,6 @@ def show_agent_logs(
|
||||
agent_role = agent_role.partition("\n")[0]
|
||||
|
||||
if formatted_answer is None:
|
||||
# Start logs
|
||||
printer.print(
|
||||
content=[
|
||||
ColoredText("# Agent: ", "bold_purple"),
|
||||
@@ -1049,7 +1030,6 @@ def show_agent_logs(
|
||||
]
|
||||
)
|
||||
else:
|
||||
# Execution logs
|
||||
printer.print(
|
||||
content=[
|
||||
ColoredText("\n\n# Agent: ", "bold_purple"),
|
||||
@@ -1182,7 +1162,6 @@ DELEGATION_TOOL_NAMES: Final[frozenset[str]] = frozenset(
|
||||
)
|
||||
|
||||
|
||||
# native tool calling tracking for delegation
|
||||
def track_delegation_if_needed(
|
||||
tool_name: str,
|
||||
tool_args: dict[str, Any],
|
||||
@@ -1428,7 +1407,6 @@ def execute_single_native_tool_call(
|
||||
|
||||
call_id, func_name, func_args = info
|
||||
|
||||
# Parse arguments
|
||||
if isinstance(func_args, str):
|
||||
try:
|
||||
args_dict = json.loads(func_args)
|
||||
@@ -1439,14 +1417,12 @@ def execute_single_native_tool_call(
|
||||
|
||||
agent_key = getattr(agent, "key", "unknown") if agent else "unknown"
|
||||
|
||||
# Find original tool for cache_function and result_as_answer
|
||||
original_tool: BaseTool | None = None
|
||||
for tool in original_tools:
|
||||
if sanitize_tool_name(tool.name) == func_name:
|
||||
original_tool = tool
|
||||
break
|
||||
|
||||
# Check cache
|
||||
from_cache = False
|
||||
input_str = json.dumps(args_dict) if args_dict else ""
|
||||
result = "Tool not found"
|
||||
@@ -1461,7 +1437,6 @@ def execute_single_native_tool_call(
|
||||
)
|
||||
from_cache = True
|
||||
|
||||
# Emit tool started event
|
||||
started_at = datetime.now()
|
||||
crewai_event_bus.emit(
|
||||
event_source,
|
||||
@@ -1476,14 +1451,12 @@ def execute_single_native_tool_call(
|
||||
|
||||
track_delegation_if_needed(func_name, args_dict, task)
|
||||
|
||||
# Find structured tool for hooks
|
||||
structured_tool: CrewStructuredTool | None = None
|
||||
for structured in structured_tools or []:
|
||||
if sanitize_tool_name(structured.name) == func_name:
|
||||
structured_tool = structured
|
||||
break
|
||||
|
||||
# Before hooks
|
||||
hook_blocked = False
|
||||
before_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
@@ -1510,7 +1483,6 @@ def execute_single_native_tool_call(
|
||||
tool_func = available_functions[func_name]
|
||||
raw_result = tool_func(**args_dict)
|
||||
|
||||
# Cache result
|
||||
if tools_handler and tools_handler.cache:
|
||||
should_cache = True
|
||||
if original_tool:
|
||||
@@ -1542,7 +1514,6 @@ def execute_single_native_tool_call(
|
||||
)
|
||||
error_event_emitted = True
|
||||
|
||||
# After hooks
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=func_name,
|
||||
tool_input=args_dict,
|
||||
@@ -1561,7 +1532,6 @@ def execute_single_native_tool_call(
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
# Emit tool finished event (only if error event wasn't already emitted)
|
||||
if not error_event_emitted:
|
||||
crewai_event_bus.emit(
|
||||
event_source,
|
||||
@@ -1577,7 +1547,6 @@ def execute_single_native_tool_call(
|
||||
),
|
||||
)
|
||||
|
||||
# Build tool result message
|
||||
tool_message: LLMMessage = {
|
||||
"role": "tool",
|
||||
"tool_call_id": call_id,
|
||||
@@ -1718,7 +1687,6 @@ def _setup_after_llm_call_hooks(
|
||||
|
||||
original_messages = executor_context.messages
|
||||
|
||||
# For Pydantic models, serialize to JSON for hooks
|
||||
if isinstance(answer, BaseModel):
|
||||
pydantic_answer = answer
|
||||
hook_response: str = pydantic_answer.model_dump_json()
|
||||
@@ -1756,9 +1724,7 @@ def _setup_after_llm_call_hooks(
|
||||
else:
|
||||
executor_context.messages = []
|
||||
|
||||
# If hooks modified the response, update answer accordingly
|
||||
if pydantic_answer is not None:
|
||||
# For Pydantic models, reparse the JSON if it was modified
|
||||
if hook_response != original_json:
|
||||
try:
|
||||
model_class: type[BaseModel] = type(pydantic_answer)
|
||||
@@ -1770,7 +1736,6 @@ def _setup_after_llm_call_hooks(
|
||||
color="yellow",
|
||||
)
|
||||
else:
|
||||
# For string responses, use the hook-modified response
|
||||
answer = hook_response
|
||||
|
||||
return answer
|
||||
|
||||
@@ -19,8 +19,6 @@ def process_config(
|
||||
if not config:
|
||||
return values
|
||||
|
||||
# Copy values from config (originally from YAML) to the model's attributes.
|
||||
# Only copy if the attribute isn't already set, preserving any explicitly defined values.
|
||||
for key, value in config.items():
|
||||
if key not in model_class.model_fields or values.get(key) is not None:
|
||||
continue
|
||||
@@ -33,6 +31,5 @@ def process_config(
|
||||
else:
|
||||
values[key] = value
|
||||
|
||||
# Remove the config from values to avoid duplicate processing
|
||||
values.pop("config", None)
|
||||
return values
|
||||
|
||||
@@ -40,14 +40,9 @@ class CrewJSONEncoder(json.JSONEncoder):
|
||||
def _handle_pydantic_model(obj: BaseModel) -> str | Any:
|
||||
try:
|
||||
data = obj.model_dump()
|
||||
# Remove circular references
|
||||
for key, value in data.items():
|
||||
if isinstance(value, BaseModel):
|
||||
data[key] = str(
|
||||
value
|
||||
) # Convert nested models to string representation
|
||||
data[key] = str(value)
|
||||
return data
|
||||
except RecursionError:
|
||||
return str(
|
||||
obj
|
||||
) # Fall back to string representation if circular reference is detected
|
||||
return str(obj)
|
||||
|
||||
@@ -137,7 +137,6 @@ class CrewEvaluator:
|
||||
avg_score = task_averages[task_index]
|
||||
agents = list(task.processed_by_agents)
|
||||
|
||||
# Add the task row with the first agent
|
||||
table.add_row(
|
||||
f"Task {task_index + 1}",
|
||||
*[f"{score:.1f}" for score in task_scores],
|
||||
@@ -145,15 +144,12 @@ class CrewEvaluator:
|
||||
f"- {agents[0]}" if agents else "",
|
||||
)
|
||||
|
||||
# Add rows for additional agents
|
||||
for agent in agents[1:]:
|
||||
table.add_row("", "", "", "", "", f"- {agent}")
|
||||
|
||||
# Add a blank separator row if it's not the last task
|
||||
if task_index < len(self.crew.tasks) - 1:
|
||||
table.add_row("", "", "", "", "", "")
|
||||
|
||||
# Add Crew and Execution Time rows
|
||||
crew_scores = [
|
||||
sum(self.tasks_scores[run]) / len(self.tasks_scores[run])
|
||||
for run in range(1, len(self.tasks_scores) + 1)
|
||||
|
||||
@@ -50,23 +50,17 @@ class FileHandler:
|
||||
Raises:
|
||||
ValueError: If file_path is neither a string nor a boolean.
|
||||
"""
|
||||
if file_path is True: # File path is boolean True
|
||||
if file_path is True:
|
||||
self._path = os.path.join(os.curdir, "logs.txt")
|
||||
|
||||
elif isinstance(file_path, str): # File path is a string
|
||||
elif isinstance(file_path, str):
|
||||
if file_path.endswith((".json", ".txt")):
|
||||
self._path = (
|
||||
file_path # No modification if the file ends with .json or .txt
|
||||
)
|
||||
self._path = file_path
|
||||
else:
|
||||
self._path = (
|
||||
file_path + ".txt"
|
||||
) # Append .txt if the file doesn't end with .json or .txt
|
||||
self._path = file_path + ".txt"
|
||||
|
||||
else:
|
||||
raise ValueError(
|
||||
"file_path must be a string or boolean."
|
||||
) # Handle the case where file_path isn't valid
|
||||
raise ValueError("file_path must be a string or boolean.")
|
||||
|
||||
def log(self, **kwargs: Unpack[LogEntry]) -> None:
|
||||
"""Log data with structured fields.
|
||||
@@ -96,14 +90,11 @@ class FileHandler:
|
||||
log_entry = {"timestamp": now, **kwargs}
|
||||
|
||||
if self._path.endswith(".json"):
|
||||
# Append log in JSON format
|
||||
try:
|
||||
# Try reading existing content to avoid overwriting
|
||||
with open(self._path, encoding="utf-8") as read_file:
|
||||
existing_data = json.load(read_file)
|
||||
existing_data.append(log_entry)
|
||||
except (json.JSONDecodeError, FileNotFoundError):
|
||||
# If no valid JSON or file doesn't exist, start with an empty list
|
||||
existing_data = [log_entry]
|
||||
|
||||
with open(self._path, "w", encoding="utf-8") as write_file:
|
||||
@@ -111,7 +102,6 @@ class FileHandler:
|
||||
write_file.write("\n")
|
||||
|
||||
else:
|
||||
# Append log in plain text format
|
||||
message = (
|
||||
f"{now}: "
|
||||
+ ", ".join(
|
||||
|
||||
@@ -113,13 +113,11 @@ def _llm_via_environment_or_fallback() -> LLM | None:
|
||||
|
||||
api_base = os.environ.get("API_BASE") or os.environ.get("AZURE_API_BASE")
|
||||
|
||||
# Synchronize base_url and api_base if one is populated and the other is not
|
||||
if base_url and not api_base:
|
||||
api_base = base_url
|
||||
elif api_base and not base_url:
|
||||
base_url = api_base
|
||||
|
||||
# Initialize llm_params dictionary
|
||||
llm_params: dict[str, Any] = {
|
||||
"model": model,
|
||||
"temperature": temperature,
|
||||
@@ -158,7 +156,6 @@ def _llm_via_environment_or_fallback() -> LLM | None:
|
||||
if key_name and key_name not in unaccepted_attributes:
|
||||
env_value = os.environ.get(key_name)
|
||||
if env_value:
|
||||
# Map environment variable names to recognized parameters
|
||||
param_key = _normalize_key_name(key_name.lower())
|
||||
llm_params[param_key] = env_value
|
||||
elif isinstance(env_var, dict):
|
||||
|
||||
@@ -8,7 +8,6 @@ from uuid import uuid4
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
|
||||
# Todo status type
|
||||
TodoStatus = Literal["pending", "running", "completed", "failed"]
|
||||
|
||||
|
||||
|
||||
@@ -78,9 +78,6 @@ class Prompts(BaseModel):
|
||||
A dictionary containing the constructed prompt(s).
|
||||
"""
|
||||
slices: list[COMPONENTS] = ["role_playing"]
|
||||
# When using native tool calling with tools, use native_tools instructions
|
||||
# When using ReAct pattern with tools, use tools instructions
|
||||
# When no tools are available, use no_tools instructions
|
||||
if self.has_tools:
|
||||
if not self.use_native_tool_calling:
|
||||
slices.append("tools")
|
||||
@@ -88,7 +85,6 @@ class Prompts(BaseModel):
|
||||
slices.append("no_tools")
|
||||
system: str = self._build_prompt(slices) + self._build_skill_block()
|
||||
|
||||
# Determine which task slice to use:
|
||||
task_slice: COMPONENTS
|
||||
if self.use_native_tool_calling:
|
||||
task_slice = "native_task"
|
||||
@@ -156,13 +152,11 @@ class Prompts(BaseModel):
|
||||
"""
|
||||
prompt: str
|
||||
if not system_template or not prompt_template:
|
||||
# If any of the required templates are missing, fall back to the default format
|
||||
prompt_parts: list[str] = [
|
||||
I18N_DEFAULT.slice(component) for component in components
|
||||
]
|
||||
prompt = "".join(prompt_parts)
|
||||
else:
|
||||
# All templates are provided, use them
|
||||
template_parts: list[str] = [
|
||||
I18N_DEFAULT.slice(component)
|
||||
for component in components
|
||||
@@ -174,7 +168,6 @@ class Prompts(BaseModel):
|
||||
prompt = prompt_template.replace(
|
||||
"{{ .Prompt }}", "".join(I18N_DEFAULT.slice("task"))
|
||||
)
|
||||
# Handle missing response_template
|
||||
if response_template:
|
||||
response: str = response_template.split("{{ .Response }}")[0]
|
||||
prompt = f"{system}\n{prompt}\n{response}"
|
||||
|
||||
@@ -43,7 +43,6 @@ class AgentReasoningOutput(BaseModel):
|
||||
plan: ReasoningPlan = Field(description="The reasoning plan for the task.")
|
||||
|
||||
|
||||
# Aliases for backward compatibility
|
||||
PlanningPlan = ReasoningPlan
|
||||
AgentPlanningOutput = AgentReasoningOutput
|
||||
|
||||
@@ -138,7 +137,6 @@ class AgentReasoning:
|
||||
"""
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
# Use task attributes if available, otherwise use provided values
|
||||
self._description = description or (
|
||||
task.description if task else "Complete the requested task"
|
||||
)
|
||||
@@ -169,7 +167,6 @@ class AgentReasoning:
|
||||
|
||||
if self.agent.planning_config is not None:
|
||||
return self.agent.planning_config
|
||||
# Fallback when planning is enabled without an explicit config
|
||||
max_attempts = getattr(self.agent, "max_reasoning_attempts", None)
|
||||
if max_attempts is not None:
|
||||
return PlanningConfig(max_attempts=max_attempts)
|
||||
@@ -196,7 +193,6 @@ class AgentReasoning:
|
||||
"""
|
||||
task_id = str(self.task.id) if self.task else "kickoff"
|
||||
|
||||
# Emit a planning started event (attempt 1)
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
@@ -208,7 +204,6 @@ class AgentReasoning:
|
||||
),
|
||||
)
|
||||
except Exception: # noqa: S110
|
||||
# Ignore event bus errors to avoid breaking execution
|
||||
pass
|
||||
|
||||
try:
|
||||
@@ -229,7 +224,6 @@ class AgentReasoning:
|
||||
|
||||
return output
|
||||
except Exception as e:
|
||||
# Emit planning failed event
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
@@ -302,7 +296,6 @@ class AgentReasoning:
|
||||
while not ready and (max_attempts is None or attempt < max_attempts):
|
||||
attempt += 1
|
||||
|
||||
# Emit event for each refinement attempt
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
@@ -328,9 +321,8 @@ class AgentReasoning:
|
||||
plan_type="refine_plan",
|
||||
)
|
||||
plan, ready = self._parse_planning_response(str(response))
|
||||
steps = [] # No structured steps from text parsing
|
||||
steps = []
|
||||
|
||||
# Emit completed event for this refinement attempt
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self.agent,
|
||||
@@ -373,7 +365,6 @@ class AgentReasoning:
|
||||
try:
|
||||
system_prompt = self._get_system_prompt()
|
||||
|
||||
# Prepare a simple callable that just returns the tool arguments as JSON
|
||||
def _create_reasoning_plan(
|
||||
plan: str,
|
||||
steps: list[dict[str, Any]] | None = None,
|
||||
@@ -395,7 +386,6 @@ class AgentReasoning:
|
||||
try:
|
||||
result = json.loads(response)
|
||||
if "plan" in result and "ready" in result:
|
||||
# Parse steps from the response
|
||||
steps: list[PlanStep] = []
|
||||
raw_steps = result.get("steps", [])
|
||||
try:
|
||||
@@ -488,11 +478,9 @@ class AgentReasoning:
|
||||
if self.config.system_prompt is not None:
|
||||
return self.config.system_prompt
|
||||
|
||||
# Try new "planning" section first, fall back to "reasoning" for compatibility
|
||||
try:
|
||||
return I18N_DEFAULT.retrieve("planning", "system_prompt")
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return I18N_DEFAULT.retrieve("reasoning", "initial_plan").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
@@ -515,7 +503,6 @@ class AgentReasoning:
|
||||
"""
|
||||
available_tools = self._format_available_tools()
|
||||
|
||||
# Use custom prompt if provided
|
||||
if self.config.plan_prompt is not None:
|
||||
return self.config.plan_prompt.format(
|
||||
role=self.agent.role,
|
||||
@@ -527,7 +514,6 @@ class AgentReasoning:
|
||||
max_steps=self.config.max_steps,
|
||||
)
|
||||
|
||||
# Try new "planning" section first
|
||||
try:
|
||||
return I18N_DEFAULT.retrieve("planning", "create_plan_prompt").format(
|
||||
description=self.description,
|
||||
@@ -536,7 +522,6 @@ class AgentReasoning:
|
||||
max_steps=self.config.max_steps,
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return I18N_DEFAULT.retrieve("reasoning", "create_plan_prompt").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
@@ -553,7 +538,6 @@ class AgentReasoning:
|
||||
Comma-separated list of tool names.
|
||||
"""
|
||||
try:
|
||||
# Try task tools first, then agent tools
|
||||
tools = []
|
||||
if self.task:
|
||||
tools = self.task.tools or []
|
||||
@@ -574,7 +558,6 @@ class AgentReasoning:
|
||||
Returns:
|
||||
The refine prompt.
|
||||
"""
|
||||
# Use custom prompt if provided
|
||||
if self.config.refine_prompt is not None:
|
||||
return self.config.refine_prompt.format(
|
||||
role=self.agent.role,
|
||||
@@ -584,13 +567,11 @@ class AgentReasoning:
|
||||
max_steps=self.config.max_steps,
|
||||
)
|
||||
|
||||
# Try new "planning" section first
|
||||
try:
|
||||
return I18N_DEFAULT.retrieve("planning", "refine_plan_prompt").format(
|
||||
current_plan=current_plan,
|
||||
)
|
||||
except (KeyError, AttributeError):
|
||||
# Fallback to reasoning section for backward compatibility
|
||||
return I18N_DEFAULT.retrieve("reasoning", "refine_plan_prompt").format(
|
||||
role=self.agent.role,
|
||||
goal=self.agent.goal,
|
||||
@@ -617,7 +598,6 @@ class AgentReasoning:
|
||||
return plan, ready
|
||||
|
||||
|
||||
# Alias for backward compatibility
|
||||
AgentPlanning = AgentReasoning
|
||||
|
||||
|
||||
|
||||
@@ -99,7 +99,6 @@ def interpolate_only(
|
||||
ValueError: If a value contains unsupported types or a template variable is missing
|
||||
"""
|
||||
|
||||
# Validation function for recursive type checking
|
||||
def _validate_type(validate_value: Any) -> None:
|
||||
if validate_value is None:
|
||||
return
|
||||
@@ -118,7 +117,6 @@ def interpolate_only(
|
||||
"Only str, int, float, bool, dict, and list are allowed."
|
||||
)
|
||||
|
||||
# Validate all input values
|
||||
for key, value in inputs.items():
|
||||
try:
|
||||
_validate_type(value)
|
||||
@@ -137,14 +135,12 @@ def interpolate_only(
|
||||
variables = _VARIABLE_PATTERN.findall(input_string)
|
||||
result = input_string
|
||||
|
||||
# Check if all variables exist in inputs
|
||||
missing_vars = [var for var in variables if var not in inputs]
|
||||
if missing_vars:
|
||||
raise KeyError(
|
||||
f"Template variable '{missing_vars[0]}' not found in inputs dictionary"
|
||||
)
|
||||
|
||||
# Replace each variable with its value
|
||||
for var in variables:
|
||||
if var in inputs:
|
||||
placeholder = "{" + var + "}"
|
||||
|
||||
@@ -13,7 +13,6 @@ from crewai.agents.agent_builder.utilities.base_token_process import TokenProces
|
||||
from crewai.utilities.logger_utils import suppress_warnings
|
||||
|
||||
|
||||
# Check if litellm is available for callback integration
|
||||
try:
|
||||
from litellm.integrations.custom_logger import CustomLogger as LiteLLMCustomLogger
|
||||
|
||||
|
||||
@@ -245,7 +245,6 @@ def execute_tool_and_check_finality(
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
# Execute after_tool_call hooks
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result: str = tool_result
|
||||
try:
|
||||
|
||||
@@ -49,9 +49,6 @@ def _pydantic_valid_event(data: dict[str, Any]) -> bool:
|
||||
return False
|
||||
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Valid server-to-client payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_SERVER_MESSAGES: list[dict[str, Any]] = [
|
||||
{
|
||||
@@ -126,9 +123,6 @@ VALID_SERVER_MESSAGES: list[dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Invalid server-to-client payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
INVALID_SERVER_MESSAGES: list[dict[str, Any]] = [
|
||||
{},
|
||||
@@ -141,9 +135,6 @@ INVALID_SERVER_MESSAGES: list[dict[str, Any]] = [
|
||||
{"unknownType": {"surfaceId": "s1"}},
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Valid client-to-server payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_CLIENT_EVENTS: list[dict[str, Any]] = [
|
||||
{
|
||||
@@ -169,9 +160,6 @@ VALID_CLIENT_EVENTS: list[dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Invalid client-to-server payloads
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
INVALID_CLIENT_EVENTS: list[dict[str, Any]] = [
|
||||
{},
|
||||
@@ -188,9 +176,7 @@ INVALID_CLIENT_EVENTS: list[dict[str, Any]] = [
|
||||
},
|
||||
]
|
||||
|
||||
# ---------------------------------------------------------------------------
|
||||
# Catalog component payloads (validated structurally)
|
||||
# ---------------------------------------------------------------------------
|
||||
|
||||
VALID_COMPONENTS: dict[str, dict[str, Any]] = {
|
||||
"Text": {"text": {"literalString": "hello"}, "usageHint": "h1"},
|
||||
|
||||
@@ -13,7 +13,6 @@ class ConcreteAgentAdapter(BaseAgentAdapter):
|
||||
def configure_tools(
|
||||
self, tools: list[BaseTool] | None = None, **kwargs: Any
|
||||
) -> None:
|
||||
# Simple implementation for testing
|
||||
self.tools = tools or []
|
||||
|
||||
def execute_task(
|
||||
@@ -94,7 +93,6 @@ def test_configure_tools_method_exists():
|
||||
adapter = ConcreteAgentAdapter(
|
||||
role="test role", goal="test goal", backstory="test backstory"
|
||||
)
|
||||
# Create dummy tools if needed, or pass None
|
||||
tools = []
|
||||
adapter.configure_tools(tools)
|
||||
assert hasattr(adapter, "tools")
|
||||
@@ -107,13 +105,11 @@ def test_configure_structured_output_method_exists():
|
||||
role="test role", goal="test goal", backstory="test backstory"
|
||||
)
|
||||
|
||||
# Define a dummy structure or pass None/Any
|
||||
class DummyOutput(BaseModel):
|
||||
data: str
|
||||
|
||||
structured_output = DummyOutput
|
||||
adapter.configure_structured_output(structured_output)
|
||||
# Add assertions here if configure_structured_output modifies state
|
||||
# For now, just ensuring it runs without error is sufficient
|
||||
|
||||
|
||||
|
||||
@@ -64,7 +64,6 @@ def test_trust_remote_completion_status_true_returns_directly():
|
||||
"history": [],
|
||||
}
|
||||
|
||||
# This should return directly without checking LLM response
|
||||
result = _delegate_to_a2a(
|
||||
self=agent,
|
||||
agent_response=MockResponse(),
|
||||
@@ -140,7 +139,6 @@ def test_trust_remote_completion_status_false_continues_conversation():
|
||||
original_task_description="test",
|
||||
)
|
||||
|
||||
# Should call original_fn to get server response
|
||||
assert call_count >= 1
|
||||
assert result == "Server final answer"
|
||||
|
||||
|
||||
@@ -28,26 +28,21 @@ from crewai.utilities import RPMController
|
||||
|
||||
|
||||
def test_agent_llm_creation_with_env_vars():
|
||||
# Store original environment variables
|
||||
original_api_key = os.environ.get("OPENAI_API_KEY")
|
||||
original_api_base = os.environ.get("OPENAI_API_BASE")
|
||||
original_model_name = os.environ.get("OPENAI_MODEL_NAME")
|
||||
|
||||
# Set up environment variables
|
||||
os.environ["OPENAI_API_KEY"] = "test_api_key"
|
||||
os.environ["OPENAI_API_BASE"] = "https://test-api-base.com"
|
||||
os.environ["OPENAI_MODEL_NAME"] = "gpt-4-turbo"
|
||||
|
||||
# Create an agent without specifying LLM
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
|
||||
# Check if LLM is created correctly
|
||||
assert isinstance(agent.llm, BaseLLM)
|
||||
assert agent.llm.model == "gpt-4-turbo"
|
||||
assert agent.llm.api_key == "test_api_key"
|
||||
assert agent.llm.base_url == "https://test-api-base.com"
|
||||
|
||||
# Clean up environment variables
|
||||
del os.environ["OPENAI_API_KEY"]
|
||||
del os.environ["OPENAI_API_BASE"]
|
||||
del os.environ["OPENAI_MODEL_NAME"]
|
||||
@@ -59,16 +54,13 @@ def test_agent_llm_creation_with_env_vars():
|
||||
if original_model_name:
|
||||
os.environ["OPENAI_MODEL_NAME"] = original_model_name
|
||||
|
||||
# Create an agent without specifying LLM
|
||||
agent = Agent(role="test role", goal="test goal", backstory="test backstory")
|
||||
|
||||
# Check if LLM is created correctly
|
||||
assert isinstance(agent.llm, BaseLLM)
|
||||
assert agent.llm.model != "gpt-4-turbo"
|
||||
assert agent.llm.api_key != "test_api_key"
|
||||
assert agent.llm.base_url != "https://test-api-base.com"
|
||||
|
||||
# Restore original environment variables
|
||||
if original_api_key:
|
||||
os.environ["OPENAI_API_KEY"] = original_api_key
|
||||
if original_api_base:
|
||||
@@ -389,7 +381,6 @@ def test_agent_custom_max_iterations():
|
||||
assert result is not None
|
||||
assert isinstance(result, str)
|
||||
assert len(result) > 0
|
||||
# With max_iter=1, exactly two provider calls are expected:
|
||||
# one inside the reasoning loop and one for the forced final answer.
|
||||
assert call_count == 2
|
||||
|
||||
@@ -584,7 +575,6 @@ def test_agent_without_max_rpm_respects_crew_rpm(capsys):
|
||||
with patch.object(RPMController, "_wait_for_next_minute") as moveon:
|
||||
moveon.return_value = True
|
||||
result = crew.kickoff()
|
||||
# Verify the crew executed and RPM limit was triggered
|
||||
assert result is not None
|
||||
assert moveon.called
|
||||
|
||||
@@ -698,7 +688,6 @@ def test_agent_definition_based_on_dict():
|
||||
assert agent.tools == []
|
||||
|
||||
|
||||
# test for human input
|
||||
@pytest.mark.vcr()
|
||||
@pytest.mark.filterwarnings("ignore::DeprecationWarning")
|
||||
def test_agent_human_input():
|
||||
@@ -725,8 +714,8 @@ def test_agent_human_input():
|
||||
# Side effect function for _prompt_input to simulate multiple feedback iterations
|
||||
feedback_responses = iter(
|
||||
[
|
||||
"Don't say hi, say Hello instead!", # First feedback: instruct change
|
||||
"", # Second feedback: empty string signals acceptance
|
||||
"Don't say hi, say Hello instead!",
|
||||
"",
|
||||
]
|
||||
)
|
||||
|
||||
@@ -746,13 +735,11 @@ def test_agent_human_input():
|
||||
return_value=AgentFinish(output="Hello", thought="", text=""),
|
||||
),
|
||||
):
|
||||
# Execute the task
|
||||
output = agent.execute_task(task)
|
||||
|
||||
# Assertions to ensure the agent behaves correctly.
|
||||
# It should have requested feedback twice.
|
||||
assert mock_prompt_input.call_count == 2
|
||||
# The final result should be processed to "Hello"
|
||||
assert output.strip().lower() == "hello"
|
||||
|
||||
|
||||
@@ -844,13 +831,10 @@ Thought:<|eot_id|>
|
||||
with patch.object(AgentExecutor, "_format_prompt") as mock_format_prompt:
|
||||
mock_format_prompt.return_value = expected_prompt
|
||||
|
||||
# Trigger the _format_prompt method
|
||||
agent.agent_executor._format_prompt("dummy_prompt", {})
|
||||
|
||||
# Assert that _format_prompt was called
|
||||
mock_format_prompt.assert_called_once()
|
||||
|
||||
# Assert that the returned prompt matches the expected prompt
|
||||
assert mock_format_prompt.return_value == expected_prompt
|
||||
|
||||
|
||||
@@ -1194,7 +1178,6 @@ def test_agent_with_callbacks():
|
||||
)
|
||||
|
||||
assert isinstance(agent.llm, BaseLLM)
|
||||
# All LLM implementations now support callbacks consistently
|
||||
assert hasattr(agent.llm, "callbacks")
|
||||
assert len(agent.llm.callbacks) == 1
|
||||
assert agent.llm.callbacks[0] == dummy_callback
|
||||
@@ -1242,14 +1225,11 @@ def test_llm_call_with_error():
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_handle_context_length_exceeds_limit():
|
||||
# Import necessary modules
|
||||
from crewai.utilities.agent_utils import handle_context_length
|
||||
from crewai_core.printer import Printer
|
||||
|
||||
# Create mocks for dependencies
|
||||
printer = Printer()
|
||||
|
||||
# Create an agent just for its LLM
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
@@ -1259,7 +1239,6 @@ def test_handle_context_length_exceeds_limit():
|
||||
|
||||
llm = agent.llm
|
||||
|
||||
# Create test messages
|
||||
messages = [
|
||||
{
|
||||
"role": "user",
|
||||
@@ -1267,11 +1246,9 @@ def test_handle_context_length_exceeds_limit():
|
||||
}
|
||||
]
|
||||
|
||||
# Set up test parameters
|
||||
respect_context_window = True
|
||||
callbacks = []
|
||||
|
||||
# Apply our patch to summarize_messages to force an error
|
||||
with patch("crewai.utilities.agent_utils.summarize_messages") as mock_summarize:
|
||||
mock_summarize.side_effect = ValueError("Context length limit exceeded")
|
||||
|
||||
@@ -1285,7 +1262,6 @@ def test_handle_context_length_exceeds_limit():
|
||||
callbacks=callbacks,
|
||||
)
|
||||
|
||||
# Verify our patch was called and raised the correct error
|
||||
assert "Context length limit exceeded" in str(excinfo.value)
|
||||
mock_summarize.assert_called_once()
|
||||
|
||||
@@ -1353,19 +1329,16 @@ def test_agent_with_all_llm_attributes():
|
||||
assert agent.llm.timeout == 10
|
||||
assert agent.llm.temperature == 0.7
|
||||
assert agent.llm.top_p == 0.9
|
||||
# assert agent.llm.n == 1
|
||||
assert set(agent.llm.stop) == set(["STOP", "END"])
|
||||
assert all(word in agent.llm.stop for word in ["STOP", "END"])
|
||||
assert agent.llm.max_tokens == 100
|
||||
assert agent.llm.presence_penalty == 0.1
|
||||
assert agent.llm.frequency_penalty == 0.1
|
||||
# assert agent.llm.logit_bias == {50256: -100}
|
||||
assert agent.llm.response_format == {"type": "json_object"}
|
||||
assert agent.llm.seed == 42
|
||||
assert agent.llm.logprobs
|
||||
assert agent.llm.top_logprobs == 5
|
||||
assert agent.llm.base_url == "https://api.openai.com/v1"
|
||||
# assert agent.llm.api_version == "2023-05-15"
|
||||
assert agent.llm.api_key == "sk-your-api-key-here"
|
||||
|
||||
|
||||
@@ -1807,7 +1780,6 @@ def test_agent_with_knowledge_sources_generate_search_query():
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
# Updated assertion to check the JSON content
|
||||
assert "Brandon" in str(agent.knowledge_search_query)
|
||||
assert "favorite color" in str(agent.knowledge_search_query)
|
||||
|
||||
@@ -1830,7 +1802,6 @@ def test_agent_with_knowledge_with_no_crewai_knowledge():
|
||||
knowledge=mock_knowledge,
|
||||
)
|
||||
|
||||
# Create a task that requires the agent to use the knowledge
|
||||
task = Task(
|
||||
description="What is Vidit's favorite color?",
|
||||
expected_output="Vidit's favorclearite color.",
|
||||
@@ -1855,7 +1826,6 @@ def test_agent_with_only_crewai_knowledge():
|
||||
),
|
||||
)
|
||||
|
||||
# Create a task that requires the agent to use the knowledge
|
||||
task = Task(
|
||||
description="What is Vidit's favorite color?",
|
||||
expected_output="Vidit's favorite color.",
|
||||
@@ -1884,7 +1854,6 @@ def test_agent_knowledege_with_crewai_knowledge():
|
||||
knowledge=agent_knowledge,
|
||||
)
|
||||
|
||||
# Create a task that requires the agent to use the knowledge
|
||||
task = Task(
|
||||
description="What is Vidit's favorite color?",
|
||||
expected_output="Vidit's favorclearite color.",
|
||||
@@ -1902,23 +1871,20 @@ def test_litellm_auth_error_handling():
|
||||
"""Test that LiteLLM authentication errors are handled correctly and not retried."""
|
||||
from litellm import AuthenticationError as LiteLLMAuthenticationError
|
||||
|
||||
# Create an agent with a mocked LLM and max_retry_limit=0
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
llm=LLM(model="gpt-4", is_litellm=True),
|
||||
max_retry_limit=0, # Disable retries for authentication errors
|
||||
max_retry_limit=0,
|
||||
)
|
||||
|
||||
# Create a task
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Mock the LLM call to raise AuthenticationError
|
||||
with (
|
||||
patch.object(LLM, "call") as mock_llm_call,
|
||||
pytest.raises(LiteLLMAuthenticationError, match="Invalid API key"),
|
||||
@@ -1928,7 +1894,6 @@ def test_litellm_auth_error_handling():
|
||||
)
|
||||
agent.execute_task(task)
|
||||
|
||||
# Verify the call was only made once (no retries)
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
|
||||
@@ -1937,7 +1902,6 @@ def test_crew_agent_executor_litellm_auth_error():
|
||||
from crewai.agents.tools_handler import ToolsHandler
|
||||
from litellm.exceptions import AuthenticationError
|
||||
|
||||
# Create an agent and executor
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
@@ -1950,7 +1914,6 @@ def test_crew_agent_executor_litellm_auth_error():
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Create executor with all required parameters
|
||||
executor = CrewAgentExecutor(
|
||||
agent=agent,
|
||||
task=task,
|
||||
@@ -1965,7 +1928,6 @@ def test_crew_agent_executor_litellm_auth_error():
|
||||
tools_handler=ToolsHandler(),
|
||||
)
|
||||
|
||||
# Mock the LLM call to raise AuthenticationError
|
||||
with (
|
||||
patch.object(LLM, "call") as mock_llm_call,
|
||||
pytest.raises(AuthenticationError) as exc_info,
|
||||
@@ -1981,10 +1943,8 @@ def test_crew_agent_executor_litellm_auth_error():
|
||||
}
|
||||
)
|
||||
|
||||
# Verify the call was only made once (no retries)
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
# Assert that the exception was raised and has the expected attributes
|
||||
assert exc_info.type is AuthenticationError
|
||||
assert "Invalid API key".lower() in exc_info.value.message.lower()
|
||||
assert exc_info.value.llm_provider == "openai"
|
||||
@@ -2004,14 +1964,12 @@ def test_litellm_anthropic_error_handling():
|
||||
max_retry_limit=0,
|
||||
)
|
||||
|
||||
# Create a task
|
||||
task = Task(
|
||||
description="Test task",
|
||||
expected_output="Test output",
|
||||
agent=agent,
|
||||
)
|
||||
|
||||
# Mock the LLM call to raise AnthropicError
|
||||
with (
|
||||
patch.object(LLM, "call") as mock_llm_call,
|
||||
pytest.raises(AnthropicError, match="Test Anthropic error"),
|
||||
@@ -2022,7 +1980,6 @@ def test_litellm_anthropic_error_handling():
|
||||
)
|
||||
agent.execute_task(task)
|
||||
|
||||
# Verify the LLM call was only made once (no retries)
|
||||
mock_llm_call.assert_called_once()
|
||||
|
||||
|
||||
|
||||
@@ -75,7 +75,7 @@ class TestAgentA2AKickoff:
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
assert isinstance(result.raw, str)
|
||||
assert len(result.raw) > 50 # Should have a meaningful response
|
||||
assert len(result.raw) > 50
|
||||
|
||||
@pytest.mark.vcr()
|
||||
def test_agent_kickoff_returns_lite_agent_output(
|
||||
@@ -99,14 +99,12 @@ class TestAgentA2AKickoff:
|
||||
self, researcher_agent: Agent
|
||||
) -> None:
|
||||
"""Test that agent handles multi-turn A2A conversations."""
|
||||
# This should trigger multiple turns of conversation
|
||||
result = researcher_agent.kickoff(
|
||||
"Ask the remote A2A agent about recent developments in AI agent communication protocols."
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
# The response should contain information about A2A or agent protocols
|
||||
assert isinstance(result.raw, str)
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@@ -119,7 +117,6 @@ class TestAgentA2AKickoff:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# This should work without A2A delegation
|
||||
result = agent.kickoff("Say hello")
|
||||
|
||||
assert result is not None
|
||||
|
||||
@@ -285,7 +285,6 @@ class TestAgentExecutor:
|
||||
|
||||
result = executor.finalize()
|
||||
|
||||
# Should return "skipped" and not set is_finished
|
||||
assert result == "skipped"
|
||||
assert executor.state.is_finished is False
|
||||
|
||||
@@ -420,7 +419,6 @@ class TestAgentExecutor:
|
||||
mock_dependencies["step_callback"] = None
|
||||
executor = _build_executor(**mock_dependencies)
|
||||
|
||||
# Should not raise error
|
||||
executor._invoke_step_callback(
|
||||
AgentFinish(thought="thinking", output="test", text="final")
|
||||
)
|
||||
@@ -738,7 +736,6 @@ class TestFlowInvoke:
|
||||
"""Test successful invoke without human feedback."""
|
||||
executor = _build_executor(**mock_dependencies)
|
||||
|
||||
# Mock kickoff to set the final answer in state
|
||||
def mock_kickoff_side_effect():
|
||||
executor.state.current_answer = AgentFinish(
|
||||
thought="final thinking", output="Final result", text="complete"
|
||||
@@ -981,7 +978,6 @@ class TestNativeToolExecution:
|
||||
executor.state.todos = TodoList(items=[])
|
||||
assert executor.check_native_todo_completion() == "todo_not_satisfied"
|
||||
|
||||
# With a current todo that has tool_to_use → satisfied
|
||||
running = TodoItem(
|
||||
step_number=1,
|
||||
description="Use the expected tool",
|
||||
@@ -991,7 +987,6 @@ class TestNativeToolExecution:
|
||||
executor.state.todos = TodoList(items=[running])
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
# With a current todo without tool_to_use → still satisfied
|
||||
running.tool_to_use = None
|
||||
assert executor.check_native_todo_completion() == "todo_satisfied"
|
||||
|
||||
@@ -1063,10 +1058,8 @@ class TestAgentExecutorPlanning:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Execute kickoff with a simple task
|
||||
result = agent.kickoff("What is 2 + 2?")
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
assert "4" in str(result)
|
||||
|
||||
@@ -1087,10 +1080,8 @@ class TestAgentExecutorPlanning:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Execute kickoff
|
||||
result = agent.kickoff("What is 3 + 3?")
|
||||
|
||||
# Verify we get a result
|
||||
assert result is not None
|
||||
assert "6" in str(result)
|
||||
|
||||
@@ -1107,13 +1098,12 @@ class TestAgentExecutorPlanning:
|
||||
goal="Help solve simple math problems",
|
||||
backstory="A helpful assistant",
|
||||
llm=llm,
|
||||
planning=False, # Explicitly disable planning
|
||||
planning=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
result = agent.kickoff("What is 5 + 5?")
|
||||
|
||||
# Should still complete successfully
|
||||
assert result is not None
|
||||
assert "10" in str(result)
|
||||
|
||||
@@ -1136,7 +1126,6 @@ class TestAgentExecutorPlanning:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Should have planning_config created from reasoning=True
|
||||
assert agent.planning_config is not None
|
||||
assert agent.planning_enabled is True
|
||||
|
||||
@@ -1158,7 +1147,6 @@ class TestAgentExecutorPlanning:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Track executor for inspection
|
||||
executor_ref = [None]
|
||||
original_invoke = AgentExecutor.invoke
|
||||
|
||||
@@ -1169,10 +1157,8 @@ class TestAgentExecutorPlanning:
|
||||
with patch.object(AgentExecutor, "invoke", capture_executor):
|
||||
result = agent.kickoff("What is 7 + 7?")
|
||||
|
||||
# Verify result
|
||||
assert result is not None
|
||||
|
||||
# If we captured an executor, check its state
|
||||
if executor_ref[0] is not None:
|
||||
# After planning, state should have plan info
|
||||
assert hasattr(executor_ref[0].state, "plan")
|
||||
@@ -1204,7 +1190,6 @@ class TestAgentExecutorPlanning:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Track the plan that gets generated
|
||||
captured_plan = [None]
|
||||
original_invoke = AgentExecutor.invoke
|
||||
|
||||
@@ -1219,13 +1204,10 @@ class TestAgentExecutorPlanning:
|
||||
"Show your work for each step."
|
||||
)
|
||||
|
||||
# Verify we got a result with step outputs
|
||||
assert result is not None
|
||||
result_str = str(result)
|
||||
# Should contain at least some mathematical content from the steps
|
||||
assert "prime" in result_str.lower() or "2" in result_str or "10" in result_str
|
||||
|
||||
# Verify a plan was generated
|
||||
assert captured_plan[0] is not None
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@@ -1268,7 +1250,6 @@ class TestAgentExecutorPlanning:
|
||||
|
||||
assert result is not None
|
||||
result_str = str(result)
|
||||
# Should contain conversion-related content
|
||||
assert "212" in result_str or "210" in result_str or "Fahrenheit" in result_str or "celsius" in result_str.lower()
|
||||
|
||||
# Plan should exist
|
||||
@@ -1357,10 +1338,8 @@ class TestResponseFormatWithKickoff:
|
||||
)
|
||||
|
||||
assert result is not None
|
||||
# The synthesis step should have produced structured output
|
||||
assert result.pydantic is not None
|
||||
assert isinstance(result.pydantic, ResearchSummary)
|
||||
# Verify the structured fields are populated
|
||||
assert len(result.pydantic.topic) > 0
|
||||
assert len(result.pydantic.key_findings) >= 1
|
||||
assert len(result.pydantic.conclusion) > 0
|
||||
@@ -1498,7 +1477,6 @@ class TestReasoningEffort:
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Capture the executor to inspect state after execution
|
||||
executor_ref = [None]
|
||||
original_invoke = AgentExecutor.invoke
|
||||
|
||||
@@ -1515,19 +1493,16 @@ class TestReasoningEffort:
|
||||
assert result is not None
|
||||
assert "10" in str(result)
|
||||
|
||||
# Verify observations were still collected (heuristic path, no LLM)
|
||||
executor = executor_ref[0]
|
||||
if executor is not None and executor.state.todos.items:
|
||||
assert len(executor.state.observations) > 0, (
|
||||
"Low effort should still record heuristic observations"
|
||||
)
|
||||
|
||||
# Verify no replan was triggered
|
||||
assert executor.state.replan_count == 0, (
|
||||
"Low effort should never trigger replanning"
|
||||
)
|
||||
|
||||
# Check execution log for reasoning_effort annotation
|
||||
observation_logs = [
|
||||
log for log in executor.state.execution_log
|
||||
if log.get("type") == "observation"
|
||||
@@ -1581,14 +1556,12 @@ class TestReasoningEffort:
|
||||
assert result is not None
|
||||
assert "10" in str(result)
|
||||
|
||||
# Verify observations were collected
|
||||
executor = executor_ref[0]
|
||||
if executor is not None and executor.state.todos.items:
|
||||
assert len(executor.state.observations) > 0, (
|
||||
"High effort should run observe() on every step"
|
||||
)
|
||||
|
||||
# Check execution log shows high reasoning_effort
|
||||
observation_logs = [
|
||||
log for log in executor.state.execution_log
|
||||
if log.get("type") == "observation"
|
||||
@@ -1610,7 +1583,6 @@ class TestReasoningEffort:
|
||||
TodoList,
|
||||
)
|
||||
|
||||
# --- Build a minimal mock executor with medium effort ---
|
||||
executor = Mock(spec=AgentExecutor)
|
||||
executor.agent = Mock()
|
||||
executor.agent.verbose = False
|
||||
@@ -1622,7 +1594,6 @@ class TestReasoningEffort:
|
||||
AgentExecutor.handle_step_observed_medium.__get__(executor)
|
||||
)
|
||||
|
||||
# --- Case 1: step succeeded → should return "continue_plan" ---
|
||||
success_todo = TodoItem(
|
||||
step_number=1,
|
||||
description="Calculate something",
|
||||
@@ -1635,7 +1606,6 @@ class TestReasoningEffort:
|
||||
remaining_plan_still_valid=True,
|
||||
)
|
||||
|
||||
# Set up state
|
||||
todo_list = TodoList(items=[success_todo])
|
||||
executor.state = Mock()
|
||||
executor.state.todos = todo_list
|
||||
@@ -1647,7 +1617,6 @@ class TestReasoningEffort:
|
||||
)
|
||||
assert success_todo.status == "completed"
|
||||
|
||||
# --- Case 2: step failed → should return "replan_now" ---
|
||||
failed_todo = TodoItem(
|
||||
step_number=2,
|
||||
description="Divide by zero",
|
||||
@@ -1687,7 +1656,6 @@ class TestReasoningEffort:
|
||||
executor.agent.planning_config = Mock()
|
||||
executor.agent.planning_config.reasoning_effort = "low"
|
||||
|
||||
# Bind the real method
|
||||
executor.handle_step_observed_low = (
|
||||
AgentExecutor.handle_step_observed_low.__get__(executor)
|
||||
)
|
||||
@@ -1764,7 +1732,6 @@ class TestReasoningEffort:
|
||||
with pytest.raises(ValidationError):
|
||||
PlanningConfig(reasoning_effort="ultra")
|
||||
|
||||
# Valid values should work
|
||||
for level in ("low", "medium", "high"):
|
||||
config = PlanningConfig(reasoning_effort=level)
|
||||
assert config.reasoning_effort == level
|
||||
@@ -1926,9 +1893,7 @@ class TestObserverResponseParsing:
|
||||
assert observation.replan_reason == "build system is misconfigured"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Max Iterations Routing
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestMaxIterationsRouting:
|
||||
@@ -1966,9 +1931,7 @@ class TestMaxIterationsRouting:
|
||||
assert result == "continue_reasoning_native"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Native Tool Call Edge Cases
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallMaxUsage:
|
||||
@@ -1984,9 +1947,7 @@ class TestNativeToolCallMaxUsage:
|
||||
assert 'result = f"Tool \'{func_name}\' has reached its maximum usage limit' in source
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Executor State Reset on Re-invoke
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestExecutorStateReset:
|
||||
@@ -2016,9 +1977,7 @@ class TestExecutorStateReset:
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Plan Generation Isolation
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestPlanGenerationIsolation:
|
||||
@@ -2038,9 +1997,7 @@ class TestPlanGenerationIsolation:
|
||||
)
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Todo Status Tracking
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestTodoStatusTracking:
|
||||
@@ -2081,9 +2038,7 @@ class TestTodoStatusTracking:
|
||||
assert len(completed) == 0
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# TodoList Result Handling
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestTodoResultHandling:
|
||||
@@ -2123,9 +2078,7 @@ class TestTodoResultHandling:
|
||||
assert item.result == "existing", "None result should not overwrite existing"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Dependency Resolution with Failed Steps
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestDependencyResolutionWithFailures:
|
||||
@@ -2169,9 +2122,7 @@ class TestDependencyResolutionWithFailures:
|
||||
assert len(ready) == 1, "Downstream todo should be ready when dep is failed"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# PlanningConfig Defaults
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestPlanningConfigDefaults:
|
||||
@@ -2195,9 +2146,7 @@ class TestPlanningConfigDefaults:
|
||||
assert config.reasoning_effort == "medium"
|
||||
|
||||
|
||||
# =========================================================================
|
||||
# Vision Image Format Contract
|
||||
# =========================================================================
|
||||
|
||||
|
||||
class TestVisionImageFormatContract:
|
||||
|
||||
@@ -8,9 +8,6 @@ from crewai import Agent, PlanningConfig, Task
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for PlanningConfig configuration (no LLM calls needed)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_planning_config_default_values():
|
||||
@@ -66,7 +63,6 @@ def test_agent_with_planning_config_custom_prompts():
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Just test that the agent is created properly
|
||||
assert agent.planning_config is not None
|
||||
assert agent.planning_config.system_prompt == custom_system_prompt
|
||||
assert agent.planning_config.plan_prompt == custom_plan_prompt
|
||||
@@ -116,7 +112,6 @@ def test_planning_enabled_property():
|
||||
"""Test the planning_enabled property on Agent."""
|
||||
llm = LLM("gpt-4o-mini")
|
||||
|
||||
# With planning_config enabled
|
||||
agent_with_planning = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test",
|
||||
@@ -126,7 +121,6 @@ def test_planning_enabled_property():
|
||||
)
|
||||
assert agent_with_planning.planning_enabled is True
|
||||
|
||||
# With planning_config disabled
|
||||
agent_disabled = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test",
|
||||
@@ -136,7 +130,6 @@ def test_planning_enabled_property():
|
||||
)
|
||||
assert agent_disabled.planning_enabled is False
|
||||
|
||||
# Without planning_config
|
||||
agent_no_planning = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test",
|
||||
@@ -146,16 +139,13 @@ def test_planning_enabled_property():
|
||||
assert agent_no_planning.planning_enabled is False
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for backward compatibility with reasoning=True (no LLM calls)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
def test_agent_with_reasoning_backward_compat():
|
||||
"""Test agent with reasoning=True (backward compatibility)."""
|
||||
llm = LLM("gpt-4o-mini")
|
||||
|
||||
# This should emit a deprecation warning
|
||||
with warnings.catch_warnings(record=True):
|
||||
warnings.simplefilter("always")
|
||||
agent = Agent(
|
||||
@@ -167,7 +157,6 @@ def test_agent_with_reasoning_backward_compat():
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Should have created a PlanningConfig internally
|
||||
assert agent.planning_config is not None
|
||||
assert agent.planning_enabled is True
|
||||
|
||||
@@ -186,14 +175,10 @@ def test_agent_with_reasoning_and_max_attempts_backward_compat():
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# Should have created a PlanningConfig with max_attempts
|
||||
assert agent.planning_config is not None
|
||||
assert agent.planning_config.max_attempts == 5
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for Agent.kickoff() with planning (uses AgentExecutor)
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
@@ -246,7 +231,7 @@ def test_agent_kickoff_with_planning_disabled():
|
||||
goal="Help solve math problems",
|
||||
backstory="A helpful assistant",
|
||||
llm=llm,
|
||||
planning=False, # Explicitly disable planning
|
||||
planning=False,
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
@@ -280,10 +265,6 @@ def test_agent_kickoff_multi_step_task_with_planning():
|
||||
assert "20" in str(result)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Tests for Agent.execute_task() with planning (uses CrewAgentExecutor)
|
||||
# These test the legacy path via handle_reasoning()
|
||||
# =============================================================================
|
||||
|
||||
|
||||
@pytest.mark.vcr()
|
||||
|
||||
@@ -394,7 +394,6 @@ class TestInvokeStepCallback:
|
||||
executor.step_callback = None
|
||||
answer = AgentFinish(thought="thinking", output="test", text="final")
|
||||
|
||||
# Should not raise
|
||||
executor._invoke_step_callback(answer)
|
||||
|
||||
|
||||
|
||||
@@ -231,7 +231,7 @@ def test_safe_repair_json():
|
||||
def test_safe_repair_json_unrepairable():
|
||||
invalid_json = "{invalid_json"
|
||||
result = parser._safe_repair_json(invalid_json)
|
||||
assert result == invalid_json # Should return the original if unrepairable
|
||||
assert result == invalid_json
|
||||
|
||||
|
||||
def test_safe_repair_json_missing_quotes():
|
||||
|
||||
@@ -19,7 +19,6 @@ from crewai.tools import BaseTool
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
|
||||
|
||||
# A simple test tool
|
||||
class SecretLookupTool(BaseTool):
|
||||
name: str = "secret_lookup"
|
||||
description: str = "A tool to lookup secrets"
|
||||
@@ -28,7 +27,6 @@ class SecretLookupTool(BaseTool):
|
||||
return "SUPERSECRETPASSWORD123"
|
||||
|
||||
|
||||
# Define Mock Search Tool
|
||||
class WebSearchTool(BaseTool):
|
||||
"""Tool for searching the web for information."""
|
||||
|
||||
@@ -37,7 +35,6 @@ class WebSearchTool(BaseTool):
|
||||
|
||||
def _run(self, query: str) -> str:
|
||||
"""Search the web for information about a topic."""
|
||||
# This is a mock implementation
|
||||
if "tokyo" in query.lower():
|
||||
return "Tokyo's population in 2023 was approximately 21 million people in the city proper, and 37 million in the greater metropolitan area."
|
||||
if "climate change" in query.lower() and "coral" in query.lower():
|
||||
@@ -45,7 +42,6 @@ class WebSearchTool(BaseTool):
|
||||
return f"Found information about {query}: This is a simulated search result for demonstration purposes."
|
||||
|
||||
|
||||
# Define Mock Calculator Tool
|
||||
class CalculatorTool(BaseTool):
|
||||
"""Tool for performing calculations."""
|
||||
|
||||
@@ -55,7 +51,6 @@ class CalculatorTool(BaseTool):
|
||||
def _run(self, expression: str) -> str:
|
||||
"""Calculate the result of a mathematical expression."""
|
||||
try:
|
||||
# Using eval with restricted builtins for test purposes only
|
||||
result = eval(expression, {"__builtins__": {}}) # noqa: S307
|
||||
return f"The result of {expression} is {result}"
|
||||
except Exception as e:
|
||||
@@ -75,7 +70,6 @@ class ResearchResult(BaseModel):
|
||||
@pytest.mark.parametrize("verbose", [True, False])
|
||||
def test_agent_kickoff_preserves_parameters(verbose):
|
||||
"""Test that Agent.kickoff() uses the correct parameters from the Agent."""
|
||||
# Create a test agent with specific parameters
|
||||
mock_llm = Mock(spec=LLM)
|
||||
mock_llm.call.return_value = "Final Answer: Test response"
|
||||
mock_llm.stop = []
|
||||
@@ -104,10 +98,8 @@ def test_agent_kickoff_preserves_parameters(verbose):
|
||||
verbose=verbose,
|
||||
)
|
||||
|
||||
# Call kickoff and verify it works
|
||||
result = agent.kickoff("Test query")
|
||||
|
||||
# Verify the agent was configured correctly
|
||||
assert agent.role == "Test Agent"
|
||||
assert agent.goal == "Test Goal"
|
||||
assert agent.backstory == "Test Backstory"
|
||||
@@ -117,7 +109,6 @@ def test_agent_kickoff_preserves_parameters(verbose):
|
||||
assert agent.max_iter == max_iter
|
||||
assert agent.verbose == verbose
|
||||
|
||||
# Verify kickoff returned a result
|
||||
assert result is not None
|
||||
assert result.raw is not None
|
||||
|
||||
@@ -125,7 +116,6 @@ def test_agent_kickoff_preserves_parameters(verbose):
|
||||
@pytest.mark.vcr()
|
||||
def test_lite_agent_with_tools():
|
||||
"""Test that Agent can use tools."""
|
||||
# Create a LiteAgent with tools
|
||||
llm = LLM(model="gpt-4o-mini")
|
||||
agent = Agent(
|
||||
role="Research Assistant",
|
||||
@@ -157,7 +147,6 @@ def test_lite_agent_with_tools():
|
||||
|
||||
agent.kickoff("What are the effects of climate change on coral reefs?")
|
||||
|
||||
# Verify tool usage events were emitted
|
||||
assert event_received.wait(timeout=5), "Timeout waiting for tool usage events"
|
||||
assert len(received_events) > 0, "Tool usage events should be emitted"
|
||||
event = received_events[0]
|
||||
@@ -269,7 +258,6 @@ async def test_lite_agent_returns_usage_metrics_async():
|
||||
"What is the population of Tokyo? Return your structured output in JSON format with the following fields: summary, confidence"
|
||||
)
|
||||
assert isinstance(result, LiteAgentOutput)
|
||||
# Check for population data in various formats (text or numeric)
|
||||
assert (
|
||||
"21 million" in result.raw
|
||||
or "37 million" in result.raw
|
||||
@@ -651,7 +639,6 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post):
|
||||
}
|
||||
mock_get.return_value = mock_response
|
||||
|
||||
# Mock the platform tool execution
|
||||
mock_post_response = Mock()
|
||||
mock_post_response.ok = True
|
||||
mock_post_response.json.return_value = {
|
||||
@@ -680,7 +667,6 @@ def test_agent_kickoff_with_platform_tools(mock_get, mock_post):
|
||||
@pytest.mark.vcr()
|
||||
def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
"""Test that Agent.kickoff() properly integrates MCP tools with LiteAgent"""
|
||||
# Setup mock MCP tools - create a proper BaseTool instance
|
||||
class MockMCPTool(BaseTool):
|
||||
name: str = "exa_search"
|
||||
description: str = "Search the web using Exa"
|
||||
@@ -690,7 +676,6 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
|
||||
mock_get_mcp_tools.return_value = [MockMCPTool()]
|
||||
|
||||
# Create agent with MCP servers
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
@@ -700,20 +685,14 @@ def test_agent_kickoff_with_mcp_tools(mock_get_mcp_tools):
|
||||
verbose=True
|
||||
)
|
||||
|
||||
# Execute kickoff
|
||||
result = agent.kickoff("Search for information about AI")
|
||||
|
||||
# Verify the result is a LiteAgentOutput
|
||||
assert isinstance(result, LiteAgentOutput)
|
||||
assert result.raw is not None
|
||||
|
||||
# Verify MCP tools were retrieved
|
||||
mock_get_mcp_tools.assert_called_once_with(["https://mcp.exa.ai/mcp?api_key=test_exa_key&profile=research"])
|
||||
|
||||
|
||||
# ============================================================================
|
||||
# Tests for LiteAgent inside Flow (magic auto-async pattern)
|
||||
# ============================================================================
|
||||
|
||||
from crewai.flow.flow import listen
|
||||
|
||||
@@ -726,7 +705,6 @@ def test_lite_agent_inside_flow_sync():
|
||||
from within a Flow automatically detects the event loop and returns a
|
||||
coroutine that the Flow framework awaits. Users don't need to use async/await.
|
||||
"""
|
||||
# Track execution
|
||||
execution_log = []
|
||||
|
||||
class TestFlow(Flow):
|
||||
@@ -748,7 +726,6 @@ def test_lite_agent_inside_flow_sync():
|
||||
flow = TestFlow()
|
||||
result = flow.kickoff()
|
||||
|
||||
# Verify the flow executed successfully
|
||||
assert "flow_started" in execution_log
|
||||
assert "agent_completed" in execution_log
|
||||
assert result is not None
|
||||
@@ -851,7 +828,6 @@ def test_lite_agent_standalone_still_works():
|
||||
verbose=False,
|
||||
)
|
||||
|
||||
# This should work normally - no Flow, no event loop
|
||||
result = agent.kickoff(messages="What is 5+5? Reply with just the number.")
|
||||
|
||||
assert result is not None
|
||||
@@ -1031,7 +1007,7 @@ def test_prepare_kickoff_param_files_override_message_files():
|
||||
)
|
||||
|
||||
assert "files" in inputs
|
||||
assert inputs["files"]["same.png"] is param_file # param takes precedence
|
||||
assert inputs["files"]["same.png"] is param_file
|
||||
|
||||
|
||||
def test_lite_agent_verbose_false_suppresses_printer_output():
|
||||
@@ -1066,11 +1042,9 @@ def test_lite_agent_verbose_false_suppresses_printer_output():
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, LiteAgentOutput)
|
||||
# Verify the printer was never called when verbose=False
|
||||
mock_printer.print.assert_not_called()
|
||||
|
||||
|
||||
# --- LiteAgent memory integration ---
|
||||
|
||||
|
||||
@pytest.mark.filterwarnings("ignore:LiteAgent is deprecated")
|
||||
|
||||
@@ -215,9 +215,7 @@ def _attach_parallel_probe_handler() -> None:
|
||||
event.finished_at.timestamp(),
|
||||
)
|
||||
|
||||
# =============================================================================
|
||||
# OpenAI Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestOpenAINativeToolCalling:
|
||||
@@ -448,9 +446,7 @@ class TestOpenAINativeToolCalling:
|
||||
unregister_after_tool_call_hook(after_hook)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Anthropic Provider Tests
|
||||
# =============================================================================
|
||||
class TestAnthropicNativeToolCalling:
|
||||
"""Tests for native tool calling with Anthropic models."""
|
||||
|
||||
@@ -559,9 +555,7 @@ class TestAnthropicNativeToolCalling:
|
||||
_assert_tools_overlapped()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Google/Gemini Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestGeminiNativeToolCalling:
|
||||
@@ -672,9 +666,7 @@ class TestGeminiNativeToolCalling:
|
||||
_assert_tools_overlapped()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Azure Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestAzureNativeToolCalling:
|
||||
@@ -688,7 +680,6 @@ class TestAzureNativeToolCalling:
|
||||
"AZURE_API_BASE": "https://test.openai.azure.com",
|
||||
"AZURE_API_VERSION": "2024-02-15-preview",
|
||||
}
|
||||
# Only patch if keys are not already in environment
|
||||
if "AZURE_API_KEY" not in os.environ:
|
||||
with patch.dict(os.environ, env_vars):
|
||||
yield
|
||||
@@ -796,9 +787,7 @@ class TestAzureNativeToolCalling:
|
||||
_assert_tools_overlapped()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Bedrock Provider Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestBedrockNativeToolCalling:
|
||||
@@ -901,9 +890,7 @@ class TestBedrockNativeToolCalling:
|
||||
_assert_tools_overlapped()
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Cross-Provider Native Tool Calling Behavior Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingBehavior:
|
||||
@@ -930,9 +917,7 @@ class TestNativeToolCallingBehavior:
|
||||
assert llm.supports_function_calling() is True
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Token Usage Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingTokenUsage:
|
||||
@@ -1000,20 +985,16 @@ def test_native_tool_calling_error_handling(failing_tool: FailingTool):
|
||||
result = agent.kickoff("Use the failing_tool to do something.")
|
||||
assert result is not None
|
||||
|
||||
# Verify error event was emitted
|
||||
assert event_received.wait(timeout=10), "ToolUsageErrorEvent was not emitted"
|
||||
assert len(received_events) >= 1
|
||||
|
||||
# Verify event attributes
|
||||
error_event = received_events[0]
|
||||
assert error_event.tool_name == "failing_tool"
|
||||
assert error_event.agent_role == agent.role
|
||||
assert "This tool always fails" in str(error_event.error)
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# Max Usage Count Tests for Native Tool Calling
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class CountingInput(BaseModel):
|
||||
@@ -1042,7 +1023,6 @@ class TestMaxUsageCountWithNativeToolCalling:
|
||||
"""Test that max_usage_count is properly tracked when using native tool calling."""
|
||||
tool = CountingTool(max_usage_count=3)
|
||||
|
||||
# Verify initial state
|
||||
assert tool.max_usage_count == 3
|
||||
assert tool.current_usage_count == 0
|
||||
|
||||
@@ -1065,7 +1045,6 @@ class TestMaxUsageCountWithNativeToolCalling:
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
crew.kickoff()
|
||||
|
||||
# Verify usage count was tracked
|
||||
assert tool.max_usage_count == 3
|
||||
assert tool.current_usage_count <= tool.max_usage_count
|
||||
|
||||
@@ -1094,7 +1073,6 @@ class TestMaxUsageCountWithNativeToolCalling:
|
||||
crew = Crew(agents=[agent], tasks=[task])
|
||||
result = crew.kickoff()
|
||||
|
||||
# The tool should have been limited to max_usage_count (2) calls
|
||||
assert result is not None
|
||||
assert tool.current_usage_count == tool.max_usage_count
|
||||
# After hitting the limit, further calls should have been rejected
|
||||
@@ -1126,14 +1104,11 @@ class TestMaxUsageCountWithNativeToolCalling:
|
||||
result = crew.kickoff()
|
||||
|
||||
assert result is not None
|
||||
# Verify the requested calls occurred while keeping usage bounded.
|
||||
assert tool.current_usage_count >= 2
|
||||
assert tool.current_usage_count <= tool.max_usage_count
|
||||
|
||||
|
||||
# =============================================================================
|
||||
# JSON Parse Error Handling Tests
|
||||
# =============================================================================
|
||||
|
||||
|
||||
class TestNativeToolCallingJsonParseError:
|
||||
|
||||
@@ -1,9 +1,8 @@
|
||||
import unittest
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import jwt
|
||||
|
||||
from crewai.auth.utils import validate_jwt_token
|
||||
import jwt
|
||||
|
||||
|
||||
@patch("crewai_core.auth.utils.PyJWKClient", return_value=MagicMock())
|
||||
@@ -12,12 +11,11 @@ class TestUtils(unittest.TestCase):
|
||||
def test_validate_jwt_token(self, mock_jwt, mock_pyjwkclient):
|
||||
mock_jwt.decode.return_value = {"exp": 1719859200}
|
||||
|
||||
# Create signing key object mock with a .key attribute
|
||||
mock_pyjwkclient.return_value.get_signing_key_from_jwt.return_value = MagicMock(
|
||||
key="mock_signing_key"
|
||||
)
|
||||
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc" # noqa: S105
|
||||
jwt_token = "aaaaa.bbbbbb.cccccc"
|
||||
|
||||
decoded_token = validate_jwt_token(
|
||||
jwt_token=jwt_token,
|
||||
@@ -48,7 +46,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.ExpiredSignatureError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -58,7 +56,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidAudienceError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -68,7 +66,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidIssuerError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -80,7 +78,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.MissingRequiredClaimError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -90,7 +88,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.exceptions.PyJWKClientError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
@@ -100,7 +98,7 @@ class TestUtils(unittest.TestCase):
|
||||
mock_jwt.decode.side_effect = jwt.InvalidTokenError
|
||||
with self.assertRaises(Exception): # noqa: B017
|
||||
validate_jwt_token(
|
||||
jwt_token="aaaaa.bbbbbb.cccccc", # noqa: S106
|
||||
jwt_token="aaaaa.bbbbbb.cccccc",
|
||||
jwks_url="https://mock_jwks_url",
|
||||
issuer="https://mock_issuer",
|
||||
audience="app_id_xxxx",
|
||||
|
||||
@@ -1,14 +1,13 @@
|
||||
import io
|
||||
import os
|
||||
import zipfile
|
||||
from unittest.mock import MagicMock, patch
|
||||
import zipfile
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from crewai_cli.cli import template_add, template_list
|
||||
from crewai_cli.remote_template.main import TemplateCommand
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -35,7 +34,6 @@ def _make_zipball(files: dict[str, str], top_dir: str = "crewAIInc-template_test
|
||||
return buf.getvalue()
|
||||
|
||||
|
||||
# --- CLI command tests ---
|
||||
|
||||
|
||||
@patch("crewai_cli.cli.TemplateCommand")
|
||||
@@ -73,7 +71,6 @@ def test_template_add_with_output_dir(mock_cls, runner):
|
||||
mock_instance.add_template.assert_called_once_with("deep_research", "my_project")
|
||||
|
||||
|
||||
# --- TemplateCommand unit tests ---
|
||||
|
||||
|
||||
class TestTemplateCommand:
|
||||
@@ -89,7 +86,6 @@ class TestTemplateCommand:
|
||||
mock_response = MagicMock()
|
||||
mock_response.json.return_value = SAMPLE_REPOS
|
||||
mock_response.raise_for_status = MagicMock()
|
||||
# Return empty on page 2 to stop pagination
|
||||
mock_empty = MagicMock()
|
||||
mock_empty.json.return_value = []
|
||||
mock_empty.raise_for_status = MagicMock()
|
||||
@@ -245,7 +241,6 @@ class TestTemplateCommand:
|
||||
|
||||
os.chdir(tmp_path)
|
||||
cmd.add_template("deep_research")
|
||||
# Should return without downloading
|
||||
|
||||
@patch.object(TemplateCommand, "_install_repo")
|
||||
@patch("crewai_cli.remote_template.main.click.prompt", return_value="2")
|
||||
|
||||
@@ -6,10 +6,10 @@ have moved to lib/cli/tests/test_cli.py.
|
||||
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
from crewai_cli.cli import reset_memories
|
||||
from crewai.crew import Crew
|
||||
from crewai_cli.cli import reset_memories
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
@@ -102,7 +102,6 @@ def test_reset_kickoff_outputs(mock_get_crews, runner):
|
||||
|
||||
def test_reset_multiple_legacy_flags_collapsed_to_single_memory_reset(mock_get_crews, runner):
|
||||
result = runner.invoke(reset_memories, ["-s", "-l"])
|
||||
# Both legacy flags collapse to a single --memory reset
|
||||
assert "deprecated" in result.output.lower()
|
||||
call_count = 0
|
||||
for crew in mock_get_crews.return_value:
|
||||
@@ -145,7 +144,6 @@ def test_reset_memory_from_many_crews(mock_get_crews, runner):
|
||||
|
||||
mock_get_crews.return_value = crews
|
||||
|
||||
# Run the command
|
||||
result = runner.invoke(reset_memories, ["--knowledge"])
|
||||
|
||||
call_count = 0
|
||||
|
||||
@@ -1,16 +1,14 @@
|
||||
"""Tests for TokenManager with atomic file operations."""
|
||||
|
||||
from datetime import datetime, timedelta
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import unittest
|
||||
from datetime import datetime, timedelta
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
from crewai_core.token_manager import TokenManager
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
class TestTokenManager(unittest.TestCase):
|
||||
@@ -147,7 +145,6 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
self.temp_dir = tempfile.mkdtemp()
|
||||
self.original_get_path = TokenManager._get_secure_storage_path
|
||||
|
||||
# Patch to use temp directory
|
||||
def mock_get_path() -> Path:
|
||||
return Path(self.temp_dir)
|
||||
|
||||
@@ -183,7 +180,6 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Create file first
|
||||
file_path = Path(self.temp_dir) / "test.txt"
|
||||
file_path.write_bytes(b"original")
|
||||
|
||||
@@ -232,7 +228,6 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
|
||||
tm._atomic_write_secure_file("test.txt", b"content")
|
||||
|
||||
# Check no temp files remain
|
||||
temp_files = list(Path(self.temp_dir).glob(".test.txt.*"))
|
||||
self.assertEqual(len(temp_files), 0)
|
||||
|
||||
@@ -286,9 +281,8 @@ class TestAtomicFileOperations(unittest.TestCase):
|
||||
mock_get_key.return_value = Fernet.generate_key()
|
||||
tm = TokenManager()
|
||||
|
||||
# Should not raise
|
||||
tm._delete_secure_file("nonexistent.txt")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
unittest.main()
|
||||
unittest.main()
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
import os
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
|
||||
import pytest
|
||||
from crewai.utilities import project_utils as utils
|
||||
import pytest
|
||||
|
||||
|
||||
def create_file(path, content):
|
||||
@@ -207,7 +207,6 @@ def temp_crew_project():
|
||||
with open(os.path.join("src", "crew.py"), "w") as f:
|
||||
f.write(crew_content)
|
||||
|
||||
# Create a src/templates directory that should be ignored
|
||||
os.makedirs(os.path.join("src", "templates"), exist_ok=True)
|
||||
with open(os.path.join("src", "templates", "crew.py"), "w") as f:
|
||||
f.write("# This should be ignored")
|
||||
@@ -274,7 +273,6 @@ def test_get_crews_ignores_template_directories(
|
||||
assert not template_crew_detected
|
||||
|
||||
|
||||
# Tests for extract_tools_metadata
|
||||
|
||||
|
||||
def test_extract_tools_metadata_empty_project(temp_project_dir):
|
||||
@@ -433,10 +431,8 @@ __all__ = ['MyTool']
|
||||
assert len(metadata) == 1
|
||||
init_params = metadata[0]["init_params_schema"]
|
||||
assert "properties" in init_params
|
||||
# Custom params should be included
|
||||
assert "api_endpoint" in init_params["properties"]
|
||||
assert "timeout" in init_params["properties"]
|
||||
# Base params should be filtered out
|
||||
assert "name" not in init_params["properties"]
|
||||
assert "description" not in init_params["properties"]
|
||||
|
||||
@@ -467,7 +463,6 @@ __all__ = ['FirstTool', 'SecondTool']
|
||||
|
||||
def test_extract_tools_metadata_multiple_init_files(temp_project_dir):
|
||||
"""Test that extract_tools_metadata extracts metadata from multiple __init__.py files."""
|
||||
# Create tool in root __init__.py
|
||||
create_init_file(
|
||||
temp_project_dir,
|
||||
"""from crewai.tools import BaseTool
|
||||
@@ -480,7 +475,6 @@ __all__ = ['RootTool']
|
||||
""",
|
||||
)
|
||||
|
||||
# Create nested package with another tool
|
||||
nested_dir = temp_project_dir / "nested"
|
||||
nested_dir.mkdir()
|
||||
create_init_file(
|
||||
@@ -537,7 +531,6 @@ class MyTool(BaseTool):
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
@@ -556,6 +549,5 @@ class MyTool(BaseTool):
|
||||
__all__ = ['MyTool']
|
||||
""",
|
||||
)
|
||||
# Should not raise, just return empty list
|
||||
metadata = utils.extract_tools_metadata(dir_path=str(temp_project_dir))
|
||||
assert metadata == []
|
||||
|
||||
@@ -157,7 +157,6 @@ async def test_mixed_handlers_with_dependencies():
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
# Verify execution order
|
||||
assert execution_order[0] == "setup"
|
||||
assert "finalize" in execution_order
|
||||
assert execution_order.index("finalize") > execution_order.index("sync_process")
|
||||
@@ -187,7 +186,6 @@ async def test_independent_handlers_run_concurrently():
|
||||
if future:
|
||||
await asyncio.wrap_future(future)
|
||||
|
||||
# Both handlers should have executed
|
||||
assert len(execution_order) == 2
|
||||
assert "handler_a" in execution_order
|
||||
assert "handler_b" in execution_order
|
||||
@@ -198,7 +196,6 @@ async def test_circular_dependency_detection():
|
||||
"""Test that circular dependencies are detected and raise an error."""
|
||||
from crewai.events.handler_graph import CircularDependencyError, build_execution_plan
|
||||
|
||||
# Create circular dependency: handler_a -> handler_b -> handler_c -> handler_a
|
||||
def handler_a(source, event: DependsTestEvent):
|
||||
pass
|
||||
|
||||
@@ -208,15 +205,13 @@ async def test_circular_dependency_detection():
|
||||
def handler_c(source, event: DependsTestEvent):
|
||||
pass
|
||||
|
||||
# Build a dependency graph with a cycle
|
||||
handlers = [handler_a, handler_b, handler_c]
|
||||
dependencies = {
|
||||
handler_a: [Depends(handler_b)],
|
||||
handler_b: [Depends(handler_c)],
|
||||
handler_c: [Depends(handler_a)], # Creates the cycle
|
||||
handler_c: [Depends(handler_a)],
|
||||
}
|
||||
|
||||
# Should raise CircularDependencyError about circular dependency
|
||||
with pytest.raises(CircularDependencyError, match="Circular dependency"):
|
||||
build_execution_plan(handlers, dependencies)
|
||||
|
||||
@@ -255,11 +250,9 @@ async def test_depends_equality():
|
||||
dep_a2 = Depends(handler_a)
|
||||
dep_b = Depends(handler_b)
|
||||
|
||||
# Same handler should be equal
|
||||
assert dep_a1 == dep_a2
|
||||
assert hash(dep_a1) == hash(dep_a2)
|
||||
|
||||
# Different handlers should not be equal
|
||||
assert dep_a1 != dep_b
|
||||
assert hash(dep_a1) != hash(dep_b)
|
||||
|
||||
@@ -282,5 +275,4 @@ async def test_aemit_ignores_dependencies():
|
||||
event = DependsTestEvent(value=1)
|
||||
await crewai_event_bus.aemit("test_source", event)
|
||||
|
||||
# Only async handler should execute
|
||||
assert execution_order == ["async_handler"]
|
||||
|
||||
@@ -837,7 +837,6 @@ class TestTriggeredByEventId:
|
||||
assert listener_b_started is not None
|
||||
assert listener_c_started is not None
|
||||
|
||||
# All parallel listeners should point to the same triggering event
|
||||
assert listener_a_started.triggered_by_event_id == trigger_finished.event_id
|
||||
assert listener_b_started.triggered_by_event_id == trigger_finished.event_id
|
||||
assert listener_c_started.triggered_by_event_id == trigger_finished.event_id
|
||||
@@ -995,23 +994,19 @@ class TestTriggeredByEventId:
|
||||
else:
|
||||
second_run_events.append(event)
|
||||
|
||||
# First kickoff
|
||||
capturing_second = False
|
||||
flow1 = ReusableFlow()
|
||||
await flow1.akickoff()
|
||||
crewai_event_bus.flush()
|
||||
|
||||
# Second kickoff
|
||||
capturing_second = True
|
||||
flow2 = ReusableFlow()
|
||||
await flow2.akickoff()
|
||||
crewai_event_bus.flush()
|
||||
|
||||
# Should have events from both runs
|
||||
assert len(first_run_events) >= 4 # 2 started + 2 finished
|
||||
assert len(second_run_events) >= 4
|
||||
|
||||
# Check first run's triggered_by chain
|
||||
first_started = [e for e in first_run_events if isinstance(e, MethodExecutionStartedEvent)]
|
||||
first_finished = [e for e in first_run_events if isinstance(e, MethodExecutionFinishedEvent)]
|
||||
|
||||
@@ -1025,7 +1020,6 @@ class TestTriggeredByEventId:
|
||||
assert first_process_started is not None
|
||||
assert first_process_started.triggered_by_event_id == first_begin_finished.event_id
|
||||
|
||||
# Check second run's triggered_by chain
|
||||
second_started = [e for e in second_run_events if isinstance(e, MethodExecutionStartedEvent)]
|
||||
second_finished = [e for e in second_run_events if isinstance(e, MethodExecutionFinishedEvent)]
|
||||
|
||||
@@ -1039,10 +1033,8 @@ class TestTriggeredByEventId:
|
||||
assert second_process_started is not None
|
||||
assert second_process_started.triggered_by_event_id == second_begin_finished.event_id
|
||||
|
||||
# Verify the two runs have different event_ids (not reusing)
|
||||
assert first_begin_finished.event_id != second_begin_finished.event_id
|
||||
|
||||
# Verify each run has its own independent previous_event_id chain
|
||||
# (chains reset at each top-level execution)
|
||||
first_sorted = sorted(first_run_events, key=lambda e: e.emission_sequence or 0)
|
||||
for event in first_sorted[1:]:
|
||||
@@ -1094,19 +1086,16 @@ class TestTriggeredByEventId:
|
||||
def capture_finished(source, event):
|
||||
events.append(event)
|
||||
|
||||
# Run two flows in parallel
|
||||
flow_a = ParallelTestFlow("flow_a")
|
||||
flow_b = ParallelTestFlow("flow_b")
|
||||
await asyncio.gather(flow_a.akickoff(), flow_b.akickoff())
|
||||
crewai_event_bus.flush()
|
||||
|
||||
# Should have events from both flows (4 events each = 8 total)
|
||||
assert len(events) >= 8
|
||||
|
||||
started_events = [e for e in events if isinstance(e, MethodExecutionStartedEvent)]
|
||||
finished_events = [e for e in events if isinstance(e, MethodExecutionFinishedEvent)]
|
||||
|
||||
# Find flow_a's events by checking the result contains "flow_a"
|
||||
flow_a_begin_finished = [
|
||||
e for e in finished_events
|
||||
if e.method_name == "begin" and "flow_a" in str(e.result)
|
||||
@@ -1124,20 +1113,16 @@ class TestTriggeredByEventId:
|
||||
assert len(flow_a_begin_finished) >= 1
|
||||
assert len(flow_b_begin_finished) >= 1
|
||||
|
||||
# Each flow's process should be triggered by its own begin
|
||||
# Find which process events were triggered by which begin events
|
||||
for process_event in flow_a_process_started:
|
||||
trigger_id = process_event.triggered_by_event_id
|
||||
assert trigger_id is not None
|
||||
|
||||
# The triggering event should be a begin finished event
|
||||
triggering_event = next(
|
||||
(e for e in finished_events if e.event_id == trigger_id), None
|
||||
)
|
||||
assert triggering_event is not None
|
||||
assert triggering_event.method_name == "begin"
|
||||
|
||||
# Verify previous_event_id forms a valid chain across all events
|
||||
all_sorted = sorted(events, key=lambda e: e.emission_sequence or 0)
|
||||
for event in all_sorted[1:]:
|
||||
assert event.previous_event_id is not None
|
||||
@@ -1236,7 +1221,7 @@ class TestTriggeredByEventId:
|
||||
try:
|
||||
await flow.akickoff()
|
||||
except ValueError:
|
||||
pass # Expected
|
||||
pass
|
||||
crewai_event_bus.flush()
|
||||
|
||||
# Even with exception, events should have proper previous_event_id chain
|
||||
@@ -1259,7 +1244,7 @@ class TestTriggeredByEventId:
|
||||
|
||||
class SyncFlow(Flow):
|
||||
@start()
|
||||
def sync_start(self): # Synchronous method
|
||||
def sync_start(self):
|
||||
return "sync_done"
|
||||
|
||||
@listen(sync_start)
|
||||
@@ -1336,7 +1321,6 @@ class TestTriggeredByEventId:
|
||||
assert start_one is not None
|
||||
assert start_two is not None
|
||||
|
||||
# Both start methods should have no triggered_by (they're entry points)
|
||||
assert start_one.triggered_by_event_id is None
|
||||
assert start_two.triggered_by_event_id is None
|
||||
|
||||
@@ -1441,7 +1425,6 @@ class TestTriggeredByEventId:
|
||||
started_events = [e for e in events if isinstance(e, MethodExecutionStartedEvent)]
|
||||
finished_events = [e for e in events if isinstance(e, MethodExecutionFinishedEvent)]
|
||||
|
||||
# Verify each level triggers the next
|
||||
for i in range(5):
|
||||
prev_finished = next(
|
||||
(e for e in finished_events if e.method_name == f"level_{i}"), None
|
||||
@@ -1518,7 +1501,6 @@ class TestTriggeredByEventId:
|
||||
# path_b should NOT be executed since router returned "path_a"
|
||||
assert handle_path_b_started is None
|
||||
|
||||
# The selected path should be triggered by the router
|
||||
assert handle_path_a_started.triggered_by_event_id == router_finished.event_id
|
||||
|
||||
|
||||
@@ -1589,7 +1571,6 @@ class TestCrewEventsInFlowTriggeredBy:
|
||||
# final should be triggered by middle_method
|
||||
assert final_started.triggered_by_event_id == middle_finished.event_id
|
||||
|
||||
# All events should have proper previous_event_id chain
|
||||
all_sorted = sorted(events, key=lambda e: e.emission_sequence or 0)
|
||||
for event in all_sorted[1:]:
|
||||
assert event.previous_event_id is not None
|
||||
@@ -1624,7 +1605,7 @@ class TestCrewEventsInFlowTriggeredBy:
|
||||
events.append(event)
|
||||
|
||||
flow = SyncKickoffFlow()
|
||||
flow.kickoff() # Synchronous kickoff
|
||||
flow.kickoff()
|
||||
crewai_event_bus.flush()
|
||||
|
||||
started_events = [e for e in events if isinstance(e, MethodExecutionStartedEvent)]
|
||||
@@ -1643,7 +1624,6 @@ class TestCrewEventsInFlowTriggeredBy:
|
||||
# Listener should be triggered by start_method
|
||||
assert listener_started.triggered_by_event_id == start_finished.event_id
|
||||
|
||||
# Verify previous_event_id chain
|
||||
all_sorted = sorted(events, key=lambda e: e.emission_sequence or 0)
|
||||
for event in all_sorted[1:]:
|
||||
assert event.previous_event_id is not None
|
||||
|
||||
@@ -14,7 +14,6 @@ def test_get_machine_id_basic():
|
||||
"""Test that _get_machine_id always returns a valid SHA256 hash."""
|
||||
machine_id = _get_machine_id()
|
||||
|
||||
# Should return a 64-character hex string (SHA256)
|
||||
assert isinstance(machine_id, str)
|
||||
assert len(machine_id) == 64
|
||||
assert all(c in "0123456789abcdef" for c in machine_id)
|
||||
@@ -25,7 +24,6 @@ def test_get_machine_id_handles_missing_files():
|
||||
with patch.object(Path, "read_text", side_effect=FileNotFoundError):
|
||||
machine_id = _get_machine_id()
|
||||
|
||||
# Should still return a valid hash even when files are missing
|
||||
assert isinstance(machine_id, str)
|
||||
assert len(machine_id) == 64
|
||||
assert all(c in "0123456789abcdef" for c in machine_id)
|
||||
@@ -36,7 +34,6 @@ def test_get_machine_id_handles_permission_errors():
|
||||
with patch.object(Path, "read_text", side_effect=PermissionError):
|
||||
machine_id = _get_machine_id()
|
||||
|
||||
# Should still return a valid hash even with permission errors
|
||||
assert isinstance(machine_id, str)
|
||||
assert len(machine_id) == 64
|
||||
assert all(c in "0123456789abcdef" for c in machine_id)
|
||||
@@ -47,7 +44,6 @@ def test_get_machine_id_handles_mac_address_failure():
|
||||
with patch("uuid.getnode", side_effect=Exception("MAC address error")):
|
||||
machine_id = _get_machine_id()
|
||||
|
||||
# Should still return a valid hash even without MAC address
|
||||
assert isinstance(machine_id, str)
|
||||
assert len(machine_id) == 64
|
||||
assert all(c in "0123456789abcdef" for c in machine_id)
|
||||
@@ -79,10 +75,8 @@ def test_get_generic_system_id_basic():
|
||||
"""Test that _get_generic_system_id returns reasonable values."""
|
||||
result = _get_generic_system_id()
|
||||
|
||||
# Should return a string or None
|
||||
assert result is None or isinstance(result, str)
|
||||
|
||||
# If it returns a string, it should be non-empty
|
||||
if result:
|
||||
assert len(result) > 0
|
||||
|
||||
@@ -92,7 +86,6 @@ def test_get_generic_system_id_handles_socket_errors():
|
||||
with patch("socket.gethostname", side_effect=Exception("Socket error")):
|
||||
result = _get_generic_system_id()
|
||||
|
||||
# Should still work or return None
|
||||
assert result is None or isinstance(result, str)
|
||||
|
||||
|
||||
@@ -101,7 +94,6 @@ def test_machine_id_consistency():
|
||||
machine_id1 = _get_machine_id()
|
||||
machine_id2 = _get_machine_id()
|
||||
|
||||
# Should be the same across calls (stable fingerprint)
|
||||
assert machine_id1 == machine_id2
|
||||
|
||||
|
||||
|
||||
@@ -77,10 +77,8 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
"""
|
||||
mock_create_llm.return_value = mock_llm
|
||||
|
||||
# Setup execution trace with sufficient LLM calls
|
||||
execution_trace = {"llm_calls": llm_calls}
|
||||
|
||||
# Mock the _detect_loops method to return a simple result
|
||||
evaluator = ReasoningEfficiencyEvaluator(llm=mock_llm)
|
||||
evaluator._detect_loops = MagicMock(return_value=(False, []))
|
||||
|
||||
@@ -99,7 +97,6 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
assert "Reasoning Efficiency Evaluation:" in result.feedback
|
||||
assert "• Focus: 8.0/10" in result.feedback
|
||||
|
||||
# Verify LLM was called
|
||||
mock_llm.call.assert_called_once()
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
@@ -110,10 +107,8 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
mock_llm.call.return_value = "Invalid JSON response"
|
||||
mock_create_llm.return_value = mock_llm
|
||||
|
||||
# Setup execution trace
|
||||
execution_trace = {"llm_calls": llm_calls}
|
||||
|
||||
# Mock the _detect_loops method
|
||||
evaluator = ReasoningEfficiencyEvaluator(llm=mock_llm)
|
||||
evaluator._detect_loops = MagicMock(return_value=(False, []))
|
||||
|
||||
@@ -132,7 +127,6 @@ class TestReasoningEfficiencyEvaluator(BaseEvaluationMetricsTest):
|
||||
|
||||
@patch("crewai.utilities.llm_utils.create_llm")
|
||||
def test_loop_detection(self, mock_create_llm, mock_agent, mock_task, mock_output):
|
||||
# Setup LLM calls with a repeating pattern
|
||||
repetitive_llm_calls = [
|
||||
{
|
||||
"prompt": "How to solve?",
|
||||
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user