Merge branch 'main' into gl/fix/task-planning-order

This commit is contained in:
Greyson LaLonde
2025-12-10 20:34:51 -05:00
committed by GitHub
414 changed files with 33207 additions and 65747 deletions

View File

@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
_suppress_pydantic_deprecation_warnings()
__version__ = "1.6.1"
__version__ = "1.7.0"
_telemetry_submitted = False

View File

@@ -0,0 +1,4 @@
"""A2A Protocol Extensions for CrewAI.
This module contains extensions to the A2A (Agent-to-Agent) protocol.
"""

View File

@@ -0,0 +1,193 @@
"""Base extension interface for A2A wrapper integrations.
This module defines the protocol for extending A2A wrapper functionality
with custom logic for conversation processing, prompt augmentation, and
agent response handling.
"""
from __future__ import annotations
from collections.abc import Sequence
from typing import TYPE_CHECKING, Any, Protocol
if TYPE_CHECKING:
from a2a.types import Message
from crewai.agent.core import Agent
class ConversationState(Protocol):
"""Protocol for extension-specific conversation state.
Extensions can define their own state classes that implement this protocol
to track conversation-specific data extracted from message history.
"""
def is_ready(self) -> bool:
"""Check if the state indicates readiness for some action.
Returns:
True if the state is ready, False otherwise.
"""
...
class A2AExtension(Protocol):
"""Protocol for A2A wrapper extensions.
Extensions can implement this protocol to inject custom logic into
the A2A conversation flow at various integration points.
"""
def inject_tools(self, agent: Agent) -> None:
"""Inject extension-specific tools into the agent.
Called when an agent is wrapped with A2A capabilities. Extensions
can add tools that enable extension-specific functionality.
Args:
agent: The agent instance to inject tools into.
"""
...
def extract_state_from_history(
self, conversation_history: Sequence[Message]
) -> ConversationState | None:
"""Extract extension-specific state from conversation history.
Called during prompt augmentation to allow extensions to analyze
the conversation history and extract relevant state information.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Extension-specific conversation state, or None if no relevant state.
"""
...
def augment_prompt(
self,
base_prompt: str,
conversation_state: ConversationState | None,
) -> str:
"""Augment the task prompt with extension-specific instructions.
Called during prompt augmentation to allow extensions to add
custom instructions based on conversation state.
Args:
base_prompt: The base prompt to augment.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The augmented prompt with extension-specific instructions.
"""
...
def process_response(
self,
agent_response: Any,
conversation_state: ConversationState | None,
) -> Any:
"""Process and potentially modify the agent response.
Called after parsing the agent's response, allowing extensions to
enhance or modify the response based on conversation state.
Args:
agent_response: The parsed agent response.
conversation_state: Extension-specific state from extract_state_from_history.
Returns:
The processed agent response (may be modified or original).
"""
...
class ExtensionRegistry:
"""Registry for managing A2A extensions.
Maintains a collection of extensions and provides methods to invoke
their hooks at various integration points.
"""
def __init__(self) -> None:
"""Initialize the extension registry."""
self._extensions: list[A2AExtension] = []
def register(self, extension: A2AExtension) -> None:
"""Register an extension.
Args:
extension: The extension to register.
"""
self._extensions.append(extension)
def inject_all_tools(self, agent: Agent) -> None:
"""Inject tools from all registered extensions.
Args:
agent: The agent instance to inject tools into.
"""
for extension in self._extensions:
extension.inject_tools(agent)
def extract_all_states(
self, conversation_history: Sequence[Message]
) -> dict[type[A2AExtension], ConversationState]:
"""Extract conversation states from all registered extensions.
Args:
conversation_history: The sequence of A2A messages exchanged.
Returns:
Mapping of extension types to their conversation states.
"""
states: dict[type[A2AExtension], ConversationState] = {}
for extension in self._extensions:
state = extension.extract_state_from_history(conversation_history)
if state is not None:
states[type(extension)] = state
return states
def augment_prompt_with_all(
self,
base_prompt: str,
extension_states: dict[type[A2AExtension], ConversationState],
) -> str:
"""Augment prompt with instructions from all registered extensions.
Args:
base_prompt: The base prompt to augment.
extension_states: Mapping of extension types to conversation states.
Returns:
The fully augmented prompt.
"""
augmented = base_prompt
for extension in self._extensions:
state = extension_states.get(type(extension))
augmented = extension.augment_prompt(augmented, state)
return augmented
def process_response_with_all(
self,
agent_response: Any,
extension_states: dict[type[A2AExtension], ConversationState],
) -> Any:
"""Process response through all registered extensions.
Args:
agent_response: The parsed agent response.
extension_states: Mapping of extension types to conversation states.
Returns:
The processed agent response.
"""
processed = agent_response
for extension in self._extensions:
state = extension_states.get(type(extension))
processed = extension.process_response(processed, state)
return processed

View File

@@ -0,0 +1,34 @@
"""Extension registry factory for A2A configurations.
This module provides utilities for creating extension registries from A2A configurations.
"""
from __future__ import annotations
from typing import TYPE_CHECKING
from crewai.a2a.extensions.base import ExtensionRegistry
if TYPE_CHECKING:
from crewai.a2a.config import A2AConfig
def create_extension_registry_from_config(
a2a_config: list[A2AConfig] | A2AConfig,
) -> ExtensionRegistry:
"""Create an extension registry from A2A configuration.
Args:
a2a_config: A2A configuration (single or list)
Returns:
Configured extension registry with all applicable extensions
"""
registry = ExtensionRegistry()
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
for _ in configs:
pass
return registry

View File

@@ -23,6 +23,8 @@ from a2a.types import (
TextPart,
TransportProtocol,
)
from aiocache import cached # type: ignore[import-untyped]
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
import httpx
from pydantic import BaseModel, Field, create_model
@@ -65,7 +67,7 @@ def _fetch_agent_card_cached(
endpoint: A2A agent endpoint URL
auth_hash: Hash of the auth object
timeout: Request timeout
_ttl_hash: Time-based hash for cache invalidation (unused in body)
_ttl_hash: Time-based hash for cache invalidation
Returns:
Cached AgentCard
@@ -106,7 +108,18 @@ def fetch_agent_card(
A2AClientHTTPError: If authentication fails
"""
if use_cache:
auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
ttl_hash = int(time.time() // cache_ttl)
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
@@ -121,6 +134,26 @@ def fetch_agent_card(
loop.close()
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
async def _fetch_agent_card_async_cached(
endpoint: str,
auth_hash: int,
timeout: int,
) -> AgentCard:
"""Cached async implementation of AgentCard fetching.
Args:
endpoint: A2A agent endpoint URL
auth_hash: Hash of the auth object
timeout: Request timeout in seconds
Returns:
Cached AgentCard object
"""
auth = _auth_store.get(auth_hash)
return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
async def _fetch_agent_card_async(
endpoint: str,
auth: AuthScheme | None,
@@ -339,7 +372,22 @@ async def _execute_a2a_delegation_async(
Returns:
Dictionary with status, result/error, and new history
"""
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
if auth:
auth_data = auth.model_dump_json(
exclude={
"_access_token",
"_token_expires_at",
"_refresh_token",
"_authorization_callback",
}
)
auth_hash = hash((type(auth).__name__, auth_data))
else:
auth_hash = 0
_auth_store[auth_hash] = auth
agent_card = await _fetch_agent_card_async_cached(
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
)
validate_auth_against_agent_card(agent_card, auth)
@@ -556,6 +604,34 @@ async def _execute_a2a_delegation_async(
}
break
except Exception as e:
if isinstance(e, A2AClientHTTPError):
error_msg = f"HTTP Error {e.status_code}: {e!s}"
error_message = Message(
role=Role.agent,
message_id=str(uuid.uuid4()),
parts=[Part(root=TextPart(text=error_msg))],
context_id=context_id,
task_id=task_id,
)
new_messages.append(error_message)
crewai_event_bus.emit(
None,
A2AResponseReceivedEvent(
response=error_msg,
turn_number=turn_number,
is_multiturn=is_multiturn,
status="failed",
agent_role=agent_role,
),
)
return {
"status": "failed",
"error": error_msg,
"history": new_messages,
}
current_exception: Exception | BaseException | None = e
while current_exception:
if hasattr(current_exception, "response"):
@@ -752,4 +828,5 @@ def get_a2a_agents_and_response_model(
Tuple of A2A agent IDs and response model
"""
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
return a2a_agents, create_agent_response_model(agent_ids)

View File

@@ -15,6 +15,7 @@ from a2a.types import Role
from pydantic import BaseModel, ValidationError
from crewai.a2a.config import A2AConfig
from crewai.a2a.extensions.base import ExtensionRegistry
from crewai.a2a.templates import (
AVAILABLE_AGENTS_TEMPLATE,
CONVERSATION_TURN_INFO_TEMPLATE,
@@ -42,7 +43,9 @@ if TYPE_CHECKING:
from crewai.tools.base_tool import BaseTool
def wrap_agent_with_a2a_instance(agent: Agent) -> None:
def wrap_agent_with_a2a_instance(
agent: Agent, extension_registry: ExtensionRegistry | None = None
) -> None:
"""Wrap an agent instance's execute_task method with A2A support.
This function modifies the agent instance by wrapping its execute_task
@@ -51,7 +54,13 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
Args:
agent: The agent instance to wrap
extension_registry: Optional registry of A2A extensions for injecting tools and custom logic
"""
if extension_registry is None:
extension_registry = ExtensionRegistry()
extension_registry.inject_all_tools(agent)
original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined]
@wraps(original_execute_task)
@@ -85,6 +94,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
agent_response_model=agent_response_model,
context=context,
tools=tools,
extension_registry=extension_registry,
)
object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))
@@ -154,6 +164,7 @@ def _execute_task_with_a2a(
agent_response_model: type[BaseModel],
context: str | None,
tools: list[BaseTool] | None,
extension_registry: ExtensionRegistry,
) -> str:
"""Wrap execute_task with A2A delegation logic.
@@ -165,6 +176,7 @@ def _execute_task_with_a2a(
context: Optional context for task execution
tools: Optional tools available to the agent
agent_response_model: Optional agent response model
extension_registry: Registry of A2A extensions
Returns:
Task execution result (either from LLM or A2A agent)
@@ -190,11 +202,12 @@ def _execute_task_with_a2a(
finally:
task.description = original_description
task.description = _augment_prompt_with_a2a(
task.description, _ = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_description,
agent_cards=agent_cards,
failed_agents=failed_agents,
extension_registry=extension_registry,
)
task.response_model = agent_response_model
@@ -204,6 +217,11 @@ def _execute_task_with_a2a(
raw_result=raw_result, agent_response_model=agent_response_model
)
if extension_registry and isinstance(agent_response, BaseModel):
agent_response = extension_registry.process_response_with_all(
agent_response, {}
)
if isinstance(agent_response, BaseModel) and isinstance(
agent_response, AgentResponseProtocol
):
@@ -217,6 +235,7 @@ def _execute_task_with_a2a(
tools=tools,
agent_cards=agent_cards,
original_task_description=original_description,
extension_registry=extension_registry,
)
return str(agent_response.message)
@@ -235,7 +254,8 @@ def _augment_prompt_with_a2a(
turn_num: int = 0,
max_turns: int | None = None,
failed_agents: dict[str, str] | None = None,
) -> str:
extension_registry: ExtensionRegistry | None = None,
) -> tuple[str, bool]:
"""Add A2A delegation instructions to prompt.
Args:
@@ -246,13 +266,14 @@ def _augment_prompt_with_a2a(
turn_num: Current turn number (0-indexed)
max_turns: Maximum allowed turns (from config)
failed_agents: Dictionary mapping failed agent endpoints to error messages
extension_registry: Optional registry of A2A extensions
Returns:
Augmented task description with A2A instructions
Tuple of (augmented prompt, disable_structured_output flag)
"""
if not agent_cards:
return task_description
return task_description, False
agents_text = ""
@@ -270,6 +291,7 @@ def _augment_prompt_with_a2a(
agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)
history_text = ""
if conversation_history:
for msg in conversation_history:
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
@@ -277,6 +299,15 @@ def _augment_prompt_with_a2a(
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
previous_a2a_conversation=history_text
)
extension_states = {}
disable_structured_output = False
if extension_registry and conversation_history:
extension_states = extension_registry.extract_all_states(conversation_history)
for state in extension_states.values():
if state.is_ready():
disable_structured_output = True
break
turn_info = ""
if max_turns is not None and conversation_history:
@@ -296,16 +327,22 @@ def _augment_prompt_with_a2a(
warning=warning,
)
return f"""{task_description}
augmented_prompt = f"""{task_description}
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
{agents_text}
{history_text}{turn_info}
"""
if extension_registry:
augmented_prompt = extension_registry.augment_prompt_with_all(
augmented_prompt, extension_states
)
return augmented_prompt, disable_structured_output
def _parse_agent_response(
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
@@ -373,7 +410,7 @@ def _handle_agent_response_and_continue(
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
agent_cards_dict[agent_id] = a2a_result["agent_card"]
task.description = _augment_prompt_with_a2a(
task.description, disable_structured_output = _augment_prompt_with_a2a(
a2a_agents=a2a_agents,
task_description=original_task_description,
conversation_history=conversation_history,
@@ -382,7 +419,38 @@ def _handle_agent_response_and_continue(
agent_cards=agent_cards_dict,
)
original_response_model = task.response_model
if disable_structured_output:
task.response_model = None
raw_result = original_fn(self, task, context, tools)
if disable_structured_output:
task.response_model = original_response_model
if disable_structured_output:
final_turn_number = turn_num + 1
result_text = str(raw_result)
crewai_event_bus.emit(
None,
A2AMessageSentEvent(
message=result_text,
turn_number=final_turn_number,
is_multiturn=True,
agent_role=self.role,
),
)
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
status="completed",
final_result=result_text,
error=None,
total_turns=final_turn_number,
),
)
return result_text, None
llm_response = _parse_agent_response(
raw_result=raw_result, agent_response_model=agent_response_model
)
@@ -425,6 +493,7 @@ def _delegate_to_a2a(
tools: list[BaseTool] | None,
agent_cards: dict[str, AgentCard] | None = None,
original_task_description: str | None = None,
extension_registry: ExtensionRegistry | None = None,
) -> str:
"""Delegate to A2A agent with multi-turn conversation support.
@@ -437,6 +506,7 @@ def _delegate_to_a2a(
tools: Optional tools available to the agent
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
original_task_description: The original task description before A2A augmentation
extension_registry: Optional registry of A2A extensions
Returns:
Result from A2A agent
@@ -447,9 +517,13 @@ def _delegate_to_a2a(
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
agent_ids = tuple(config.endpoint for config in a2a_agents)
current_request = str(agent_response.message)
agent_id = agent_response.a2a_ids[0]
if agent_id not in agent_ids:
if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
agent_id = agent_response.a2a_ids[0]
else:
agent_id = agent_ids[0] if agent_ids else ""
if agent_id and agent_id not in agent_ids:
raise ValueError(
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
)
@@ -458,10 +532,11 @@ def _delegate_to_a2a(
task_config = task.config or {}
context_id = task_config.get("context_id")
task_id_config = task_config.get("task_id")
reference_task_ids = task_config.get("reference_task_ids")
metadata = task_config.get("metadata")
extensions = task_config.get("extensions")
reference_task_ids = task_config.get("reference_task_ids", [])
if original_task_description is None:
original_task_description = task.description
@@ -497,11 +572,27 @@ def _delegate_to_a2a(
conversation_history = a2a_result.get("history", [])
if conversation_history:
latest_message = conversation_history[-1]
if latest_message.task_id is not None:
task_id_config = latest_message.task_id
if latest_message.context_id is not None:
context_id = latest_message.context_id
if a2a_result["status"] in ["completed", "input_required"]:
if (
a2a_result["status"] == "completed"
and agent_config.trust_remote_completion_status
):
if (
task_id_config is not None
and task_id_config not in reference_task_ids
):
reference_task_ids.append(task_id_config)
if task.config is None:
task.config = {}
task.config["reference_task_ids"] = reference_task_ids
result_text = a2a_result.get("result", "")
final_turn_number = turn_num + 1
crewai_event_bus.emit(
@@ -513,7 +604,7 @@ def _delegate_to_a2a(
total_turns=final_turn_number,
),
)
return result_text # type: ignore[no-any-return]
return cast(str, result_text)
final_result, next_request = _handle_agent_response_and_continue(
self=self,
@@ -541,6 +632,31 @@ def _delegate_to_a2a(
continue
error_msg = a2a_result.get("error", "Unknown error")
final_result, next_request = _handle_agent_response_and_continue(
self=self,
a2a_result=a2a_result,
agent_id=agent_id,
agent_cards=agent_cards,
a2a_agents=a2a_agents,
original_task_description=original_task_description,
conversation_history=conversation_history,
turn_num=turn_num,
max_turns=max_turns,
task=task,
original_fn=original_fn,
context=context,
tools=tools,
agent_response_model=agent_response_model,
)
if final_result is not None:
return final_result
if next_request is not None:
current_request = next_request
continue
crewai_event_bus.emit(
None,
A2AConversationCompletedEvent(
@@ -550,7 +666,7 @@ def _delegate_to_a2a(
total_turns=turn_num + 1,
),
)
raise Exception(f"A2A delegation failed: {error_msg}")
return f"A2A delegation failed: {error_msg}"
if conversation_history:
for msg in reversed(conversation_history):

View File

@@ -2,7 +2,6 @@ from __future__ import annotations
import asyncio
from collections.abc import Sequence
import json
import shutil
import subprocess
import time
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
from typing_extensions import Self
from crewai.a2a.config import A2AConfig
from crewai.agent.utils import (
ahandle_knowledge_retrieval,
apply_training_data,
build_task_prompt_with_schema,
format_task_with_context,
get_knowledge_config,
handle_knowledge_retrieval,
handle_reasoning,
prepare_tools,
process_tool_results,
save_last_messages,
validate_max_execution_time,
)
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.agents.crew_agent_executor import CrewAgentExecutor
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
KnowledgeQueryCompletedEvent,
KnowledgeQueryFailedEvent,
KnowledgeQueryStartedEvent,
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.events.types.memory_events import (
MemoryRetrievalCompletedEvent,
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
)
from crewai.knowledge.knowledge import Knowledge
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.mcp import (
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
render_text_description_and_args,
)
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
from crewai.utilities.converter import Converter, generate_model_description
from crewai.utilities.converter import Converter
from crewai.utilities.guardrail_types import GuardrailType
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.prompts import Prompts
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
if self.reasoning:
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=self)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
# Add the reasoning plan to the task description
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
self._logger.log("error", f"Error during reasoning process: {e!s}")
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
# If the task requires output in JSON or Pydantic format,
# append specific instructions to the task prompt to ensure
# that the final answer does not include any code block markers
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
if (task.output_json or task.output_pydantic) and not task.response_model:
# Generate the schema based on the output format
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + self.i18n.slice(
"formatted_task_instructions"
).format(output_format=schema)
if context:
task_prompt = self.i18n.slice("task_with_context").format(
task=task_prompt, context=context
)
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
from_task=task,
),
)
knowledge_config = (
self.knowledge_config.model_dump() if self.knowledge_config else {}
knowledge_config = get_knowledge_config(self)
task_prompt = handle_knowledge_retrieval(
self,
task,
task_prompt,
knowledge_config,
self.knowledge.query if self.knowledge else lambda *a, **k: None,
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
)
if self.knowledge or (self.crew and self.crew.knowledge):
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=self,
),
)
try:
self.knowledge_search_query = self._get_knowledge_search_query(
task_prompt, task
)
if self.knowledge_search_query:
# Quering agent specific knowledge
if self.knowledge:
agent_knowledge_snippets = self.knowledge.query(
[self.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
self.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if self.agent_knowledge_context:
task_prompt += self.agent_knowledge_context
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
# Quering crew specific knowledge
knowledge_snippets = self.crew.query_knowledge(
[self.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
self.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if self.crew_knowledge_context:
task_prompt += self.crew_knowledge_context
crewai_event_bus.emit(
self,
event=KnowledgeRetrievalCompletedEvent(
query=self.knowledge_search_query,
from_task=task,
from_agent=self,
retrieved_knowledge=(
(self.agent_knowledge_context or "")
+ (
"\n"
if self.agent_knowledge_context
and self.crew_knowledge_context
else ""
)
+ (self.crew_knowledge_context or "")
),
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=KnowledgeSearchQueryFailedEvent(
query=self.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=self,
),
)
tools = tools or self.tools or []
self.create_agent_executor(tools=tools, task=task)
if self.crew and self.crew._train:
task_prompt = self._training_handler(task_prompt=task_prompt)
else:
task_prompt = self._use_trained_data(task_prompt=task_prompt)
# Import agent events locally to avoid circular imports
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
),
)
# Determine execution method based on timeout setting
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
if (
not isinstance(self.max_execution_time, int)
or self.max_execution_time <= 0
):
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
result = self._execute_with_timeout(
task_prompt, task, self.max_execution_time
)
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
result = self._execute_without_timeout(task_prompt, task)
except TimeoutError as e:
# Propagate TimeoutError without retry
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
# Do not retry on litellm errors
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
# If there was any tool in self.tools_results that had result_as_answer
# set to True, return the results of the last tool that had
# result_as_answer set to True
for tool_result in self.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
self._last_messages = (
self.agent_executor.messages.copy()
if self.agent_executor and hasattr(self.agent_executor, "messages")
else []
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
}
)["output"]
async def aexecute_task(
self,
task: Task,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> Any:
"""Execute a task with the agent asynchronously.
Args:
task: Task to execute.
context: Context to execute the task in.
tools: Tools to use for the task.
Returns:
Output of the agent.
Raises:
TimeoutError: If execution exceeds the maximum execution time.
ValueError: If the max execution time is not a positive integer.
RuntimeError: If the agent execution fails for other reasons.
"""
handle_reasoning(self, task)
self._inject_date_to_task(task)
if self.tools_handler:
self.tools_handler.last_used_tool = None
task_prompt = task.prompt()
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
if self._is_any_available_memory():
crewai_event_bus.emit(
self,
event=MemoryRetrievalStartedEvent(
task_id=str(task.id) if task else None,
source_type="agent",
from_agent=self,
from_task=task,
),
)
start_time = time.time()
contextual_memory = ContextualMemory(
self.crew._short_term_memory,
self.crew._long_term_memory,
self.crew._entity_memory,
self.crew._external_memory,
agent=self,
task=task,
)
memory = await contextual_memory.abuild_context_for_task(
task, context or ""
)
if memory.strip() != "":
task_prompt += self.i18n.slice("memory").format(memory=memory)
crewai_event_bus.emit(
self,
event=MemoryRetrievalCompletedEvent(
task_id=str(task.id) if task else None,
memory_content=memory,
retrieval_time_ms=(time.time() - start_time) * 1000,
source_type="agent",
from_agent=self,
from_task=task,
),
)
knowledge_config = get_knowledge_config(self)
task_prompt = await ahandle_knowledge_retrieval(
self, task, task_prompt, knowledge_config
)
prepare_tools(self, tools, task)
task_prompt = apply_training_data(self, task_prompt)
from crewai.events.types.agent_events import (
AgentExecutionCompletedEvent,
AgentExecutionErrorEvent,
AgentExecutionStartedEvent,
)
try:
crewai_event_bus.emit(
self,
event=AgentExecutionStartedEvent(
agent=self,
tools=self.tools,
task_prompt=task_prompt,
task=task,
),
)
validate_max_execution_time(self.max_execution_time)
if self.max_execution_time is not None:
result = await self._aexecute_with_timeout(
task_prompt, task, self.max_execution_time
)
else:
result = await self._aexecute_without_timeout(task_prompt, task)
except TimeoutError as e:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
self._times_executed += 1
if self._times_executed > self.max_retry_limit:
crewai_event_bus.emit(
self,
event=AgentExecutionErrorEvent(
agent=self,
task=task,
error=str(e),
),
)
raise e
result = await self.aexecute_task(task, context, tools)
if self.max_rpm and self._rpm_controller:
self._rpm_controller.stop_rpm_counter()
result = process_tool_results(self, result)
crewai_event_bus.emit(
self,
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
)
save_last_messages(self)
self._cleanup_mcp_clients()
return result
async def _aexecute_with_timeout(
self, task_prompt: str, task: Task, timeout: int
) -> Any:
"""Execute a task with a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
timeout: Maximum execution time in seconds.
Returns:
The output of the agent.
Raises:
TimeoutError: If execution exceeds the timeout.
RuntimeError: If execution fails for other reasons.
"""
try:
return await asyncio.wait_for(
self._aexecute_without_timeout(task_prompt, task),
timeout=timeout,
)
except asyncio.TimeoutError as e:
raise TimeoutError(
f"Task '{task.description}' execution timed out after {timeout} seconds. "
"Consider increasing max_execution_time or optimizing the task."
) from e
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
"""Execute a task without a timeout asynchronously.
Args:
task_prompt: The prompt to send to the agent.
task: The task being executed.
Returns:
The output of the agent.
"""
if not self.agent_executor:
raise RuntimeError("Agent executor is not initialized.")
result = await self.agent_executor.ainvoke(
{
"input": task_prompt,
"tool_names": self.agent_executor.tools_names,
"tools": self.agent_executor.tools_description,
"ask_for_human_input": task.human_input,
}
)
return result["output"]
def create_agent_executor(
self, tools: list[BaseTool] | None = None, task: Task | None = None
) -> None:
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
)
self.agent_executor = CrewAgentExecutor(
llm=self.llm,
llm=self.llm, # type: ignore[arg-type]
task=task, # type: ignore[arg-type]
agent=self,
crew=self.crew,
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
from crewai.tools.base_tool import BaseTool
from crewai.tools.mcp_native_tool import MCPNativeTool
transport: StdioTransport | HTTPTransport | SSETransport
if isinstance(mcp_config, MCPServerStdio):
transport = StdioTransport(
command=mcp_config.command,
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
server_name=server_name,
run_context=None,
)
if mcp_config.tool_filter(context, tool):
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
except (TypeError, AttributeError):
if mcp_config.tool_filter(tool):
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
filtered_tools.append(tool)
else:
# Not callable - include tool
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
path = parsed.path.replace("/", "_").strip("_")
return f"{domain}_{path}" if path else domain
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
def _get_mcp_tool_schemas(
self, server_params: dict[str, Any]
) -> dict[str, dict[str, Any]]:
"""Get tool schemas from MCP server for wrapper creation with caching."""
server_url = server_params["url"]
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
self._logger.log(
"debug", f"Using cached MCP tool schemas for {server_url}"
)
return cached_data
return cached_data # type: ignore[no-any-return]
try:
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
async def _get_mcp_tool_schemas_async(
self, server_params: dict[str, Any]
) -> dict[str, dict]:
) -> dict[str, dict[str, Any]]:
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
server_url = server_params["url"]
return await self._retry_mcp_discovery(
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
)
async def _retry_mcp_discovery(
self, operation_func, server_url: str
self, operation_func: Any, server_url: str
) -> dict[str, dict[str, Any]]:
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
last_error = None
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
@staticmethod
async def _attempt_mcp_discovery(
operation_func, server_url: str
operation_func: Any, server_url: str
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
try:
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
properties = json_schema.get("properties", {})
required_fields = json_schema.get("required", [])
field_definitions = {}
field_definitions: dict[str, Any] = {}
for field_name, field_schema in properties.items():
field_type = self._json_type_to_python(field_schema)
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
)
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
return create_model(model_name, **field_definitions)
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
"""Convert JSON Schema type to Python type.
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
json_type = field_schema.get("type")
if "anyOf" in field_schema:
types = []
types: list[type] = []
for option in field_schema["anyOf"]:
if "const" in option:
types.append(str)
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
types.append(self._json_type_to_python(option))
unique_types = list(set(types))
if len(unique_types) > 1:
result = unique_types[0]
result: Any = unique_types[0]
for t in unique_types[1:]:
result = result | t
return result
return result # type: ignore[no-any-return]
return unique_types[0]
type_mapping = {
type_mapping: dict[str | None, type] = {
"string": str,
"number": float,
"integer": int,
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
return type_mapping.get(json_type, Any)
@staticmethod
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
"""Fetch MCP server configurations from CrewAI AOP API."""
# TODO: Implement AMP API call to "integrations/mcps" endpoint
# Should return list of server configs with URLs
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
"""
if self.apps:
platform_tools = self.get_platform_tools(self.apps)
if platform_tools:
if platform_tools and self.tools is not None:
self.tools.extend(platform_tools)
if self.mcps:
mcps = self.get_mcp_tools(self.mcps)
if mcps:
if mcps and self.tools is not None:
self.tools.extend(mcps)
lite_agent = LiteAgent(

View File

@@ -4,9 +4,8 @@ This metaclass enables extension capabilities for agents by detecting
extension fields in class annotations and applying appropriate wrappers.
"""
import warnings
from functools import wraps
from typing import Any
import warnings
from pydantic import model_validator
from pydantic._internal._model_construction import ModelMetaclass
@@ -59,9 +58,15 @@ class AgentMeta(ModelMetaclass):
a2a_value = getattr(self, "a2a", None)
if a2a_value is not None:
from crewai.a2a.extensions.registry import (
create_extension_registry_from_config,
)
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
wrap_agent_with_a2a_instance(self)
extension_registry = create_extension_registry_from_config(
a2a_value
)
wrap_agent_with_a2a_instance(self, extension_registry)
return result

View File

@@ -0,0 +1,355 @@
"""Utility functions for agent task execution.
This module contains shared logic extracted from the Agent's execute_task
and aexecute_task methods to reduce code duplication.
"""
from __future__ import annotations
import json
from typing import TYPE_CHECKING, Any
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.knowledge_events import (
KnowledgeRetrievalCompletedEvent,
KnowledgeRetrievalStartedEvent,
KnowledgeSearchQueryFailedEvent,
)
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
from crewai.utilities.converter import generate_model_description
if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.i18n import I18N
def handle_reasoning(agent: Agent, task: Task) -> None:
"""Handle the reasoning process for an agent before task execution.
Args:
agent: The agent performing the task.
task: The task to execute.
"""
if not agent.reasoning:
return
try:
from crewai.utilities.reasoning_handler import (
AgentReasoning,
AgentReasoningOutput,
)
reasoning_handler = AgentReasoning(task=task, agent=agent)
reasoning_output: AgentReasoningOutput = (
reasoning_handler.handle_agent_reasoning()
)
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
except Exception as e:
agent._logger.log("error", f"Error during reasoning process: {e!s}")
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
Args:
task: The task being executed.
task_prompt: The initial task prompt.
i18n: Internationalization instance.
Returns:
The task prompt potentially augmented with schema instructions.
"""
if (task.output_json or task.output_pydantic) and not task.response_model:
if task.output_json:
schema_dict = generate_model_description(task.output_json)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
elif task.output_pydantic:
schema_dict = generate_model_description(task.output_pydantic)
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
output_format=schema
)
return task_prompt
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
"""Format task prompt with context if provided.
Args:
task_prompt: The task prompt.
context: Optional context string.
i18n: Internationalization instance.
Returns:
The task prompt formatted with context if provided.
"""
if context:
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
return task_prompt
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
"""Get knowledge configuration from agent.
Args:
agent: The agent instance.
Returns:
Dictionary of knowledge configuration.
"""
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
def handle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
query_func: Any,
crew_query_func: Any,
) -> str:
"""Handle knowledge retrieval for task execution.
This function handles both agent-specific and crew-specific knowledge queries.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
query_func: Function to query agent knowledge (sync or async).
crew_query_func: Function to query crew knowledge (sync or async).
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = query_func(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = crew_query_func(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt
def _combine_knowledge_context(agent: Agent) -> str:
"""Combine agent and crew knowledge contexts into a single string.
Args:
agent: The agent with knowledge contexts.
Returns:
Combined knowledge context string.
"""
agent_ctx = agent.agent_knowledge_context or ""
crew_ctx = agent.crew_knowledge_context or ""
separator = "\n" if agent_ctx and crew_ctx else ""
return agent_ctx + separator + crew_ctx
def apply_training_data(agent: Agent, task_prompt: str) -> str:
"""Apply training data to the task prompt.
Args:
agent: The agent performing the task.
task_prompt: The task prompt.
Returns:
The task prompt with training data applied.
"""
if agent.crew and agent.crew._train:
return agent._training_handler(task_prompt=task_prompt)
return agent._use_trained_data(task_prompt=task_prompt)
def process_tool_results(agent: Agent, result: Any) -> Any:
"""Process tool results, returning result_as_answer if applicable.
Args:
agent: The agent with tool results.
result: The current result.
Returns:
The final result, potentially overridden by tool result_as_answer.
"""
for tool_result in agent.tools_results:
if tool_result.get("result_as_answer", False):
result = tool_result["result"]
return result
def save_last_messages(agent: Agent) -> None:
"""Save the last messages from agent executor.
Args:
agent: The agent instance.
"""
agent._last_messages = (
agent.agent_executor.messages.copy()
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
else []
)
def prepare_tools(
agent: Agent, tools: list[BaseTool] | None, task: Task
) -> list[BaseTool]:
"""Prepare tools for task execution and create agent executor.
Args:
agent: The agent instance.
tools: Optional list of tools.
task: The task being executed.
Returns:
The list of tools to use.
"""
final_tools = tools or agent.tools or []
agent.create_agent_executor(tools=final_tools, task=task)
return final_tools
def validate_max_execution_time(max_execution_time: int | None) -> None:
"""Validate max_execution_time parameter.
Args:
max_execution_time: The maximum execution time to validate.
Raises:
ValueError: If max_execution_time is not a positive integer.
"""
if max_execution_time is not None:
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
raise ValueError(
"Max Execution time must be a positive integer greater than zero"
)
async def ahandle_knowledge_retrieval(
agent: Agent,
task: Task,
task_prompt: str,
knowledge_config: dict[str, Any],
) -> str:
"""Handle async knowledge retrieval for task execution.
Args:
agent: The agent performing the task.
task: The task being executed.
task_prompt: The current task prompt.
knowledge_config: Knowledge configuration dictionary.
Returns:
The task prompt potentially augmented with knowledge context.
"""
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
return task_prompt
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalStartedEvent(
from_task=task,
from_agent=agent,
),
)
try:
agent.knowledge_search_query = agent._get_knowledge_search_query(
task_prompt, task
)
if agent.knowledge_search_query:
if agent.knowledge:
agent_knowledge_snippets = await agent.knowledge.aquery(
[agent.knowledge_search_query], **knowledge_config
)
if agent_knowledge_snippets:
agent.agent_knowledge_context = extract_knowledge_context(
agent_knowledge_snippets
)
if agent.agent_knowledge_context:
task_prompt += agent.agent_knowledge_context
knowledge_snippets = await agent.crew.aquery_knowledge(
[agent.knowledge_search_query], **knowledge_config
)
if knowledge_snippets:
agent.crew_knowledge_context = extract_knowledge_context(
knowledge_snippets
)
if agent.crew_knowledge_context:
task_prompt += agent.crew_knowledge_context
crewai_event_bus.emit(
agent,
event=KnowledgeRetrievalCompletedEvent(
query=agent.knowledge_search_query,
from_task=task,
from_agent=agent,
retrieved_knowledge=_combine_knowledge_context(agent),
),
)
except Exception as e:
crewai_event_bus.emit(
agent,
event=KnowledgeSearchQueryFailedEvent(
query=agent.knowledge_search_query or "",
error=str(e),
from_task=task,
from_agent=agent,
),
)
return task_prompt

View File

@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
if not mcps:
return mcps
validated_mcps = []
validated_mcps: list[str | MCPServerConfig] = []
for mcp in mcps:
if isinstance(mcp, str):
if mcp.startswith(("https://", "crewai-amp:")):
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
) -> str:
pass
@abstractmethod
async def aexecute_task(
self,
task: Any,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> str:
"""Execute a task asynchronously."""
@abstractmethod
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
pass

View File

@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
get_before_llm_call_hooks,
)
from crewai.utilities.agent_utils import (
aget_llm_response,
enforce_rpm_limit,
format_message_for_llm,
get_llm_response,
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
from crewai.utilities.constants import TRAINING_DATA_FILE
from crewai.utilities.i18n import I18N, get_i18n
from crewai.utilities.printer import Printer
from crewai.utilities.tool_utils import execute_tool_and_check_finality
from crewai.utilities.tool_utils import (
aexecute_tool_and_check_finality,
execute_tool_and_check_finality,
)
from crewai.utilities.training_handler import CrewTrainingHandler
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self.messages: list[LLMMessage] = []
self.iterations = 0
self.log_error_after = 3
self.before_llm_call_hooks: list[Callable] = []
self.after_llm_call_hooks: list[Callable] = []
self.before_llm_call_hooks: list[Callable[..., Any]] = []
self.after_llm_call_hooks: list[Callable[..., Any]] = []
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
if self.llm:
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
self._show_logs(formatted_answer)
return formatted_answer
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
"""Execute the agent asynchronously with given inputs.
Args:
inputs: Input dictionary containing prompt variables.
Returns:
Dictionary with agent output.
"""
if "system" in self.prompt:
system_prompt = self._format_prompt(
cast(str, self.prompt.get("system", "")), inputs
)
user_prompt = self._format_prompt(
cast(str, self.prompt.get("user", "")), inputs
)
self.messages.append(format_message_for_llm(system_prompt, role="system"))
self.messages.append(format_message_for_llm(user_prompt))
else:
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
self.messages.append(format_message_for_llm(user_prompt))
self._show_start_logs()
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
try:
formatted_answer = await self._ainvoke_loop()
except AssertionError:
self._printer.print(
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
color="red",
)
raise
except Exception as e:
handle_unknown_error(self._printer, e)
raise
if self.ask_for_human_input:
formatted_answer = self._handle_human_feedback(formatted_answer)
self._create_short_term_memory(formatted_answer)
self._create_long_term_memory(formatted_answer)
self._create_external_memory(formatted_answer)
return {"output": formatted_answer.output}
async def _ainvoke_loop(self) -> AgentFinish:
"""Execute agent loop asynchronously until completion.
Returns:
Final answer from the agent.
"""
formatted_answer = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self.iterations, self.max_iter):
formatted_answer = handle_max_iterations_exceeded(
formatted_answer,
printer=self._printer,
i18n=self._i18n,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
)
break
enforce_rpm_limit(self.request_within_rpm_limit)
answer = await aget_llm_response(
llm=self.llm,
messages=self.messages,
callbacks=self.callbacks,
printer=self._printer,
from_task=self.task,
from_agent=self.agent,
response_model=self.response_model,
executor_context=self,
)
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
if isinstance(formatted_answer, AgentAction):
fingerprint_context = {}
if (
self.agent
and hasattr(self.agent, "security_config")
and hasattr(self.agent.security_config, "fingerprint")
):
fingerprint_context = {
"agent_fingerprint": str(
self.agent.security_config.fingerprint
)
}
tool_result = await aexecute_tool_and_check_finality(
agent_action=formatted_answer,
fingerprint_context=fingerprint_context,
tools=self.tools,
i18n=self._i18n,
agent_key=self.agent.key if self.agent else None,
agent_role=self.agent.role if self.agent else None,
tools_handler=self.tools_handler,
task=self.task,
agent=self.agent,
function_calling_llm=self.function_calling_llm,
crew=self.crew,
)
formatted_answer = self._handle_agent_action(
formatted_answer, tool_result
)
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
except OutputParserError as e:
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
e=e,
messages=self.messages,
iterations=self.iterations,
log_error_after=self.log_error_after,
printer=self._printer,
)
except Exception as e:
if e.__class__.__module__.startswith("litellm"):
raise e
if is_context_length_exceeded(e):
handle_context_length(
respect_context_window=self.respect_context_window,
printer=self._printer,
messages=self.messages,
llm=self.llm,
callbacks=self.callbacks,
i18n=self._i18n,
)
continue
handle_unknown_error(self._printer, e)
raise e
finally:
self.iterations += 1
if not isinstance(formatted_answer, AgentFinish):
raise RuntimeError(
"Agent execution ended without reaching a final answer. "
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
)
self._show_logs(formatted_answer)
return formatted_answer
def _handle_agent_action(
self, formatted_answer: AgentAction, tool_result: ToolResult
) -> AgentAction | AgentFinish:

View File

@@ -14,7 +14,8 @@ import tomli
from crewai.cli.utils import read_toml
from crewai.cli.version import get_crewai_version
from crewai.crew import Crew
from crewai.llm import LLM, BaseLLM
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
from crewai.types.crew_chat import ChatInputField, ChatInputs
from crewai.utilities.llm_utils import create_llm
from crewai.utilities.printer import Printer
@@ -27,7 +28,7 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
def check_conversational_crews_version(
crewai_version: str, pyproject_data: dict
crewai_version: str, pyproject_data: dict[str, Any]
) -> bool:
"""
Check if the installed crewAI version supports conversational crews.
@@ -53,7 +54,7 @@ def check_conversational_crews_version(
return True
def run_chat():
def run_chat() -> None:
"""
Runs an interactive chat loop using the Crew's chat LLM with function calling.
Incorporates crew_name, crew_description, and input fields to build a tool schema.
@@ -101,7 +102,7 @@ def run_chat():
click.secho(f"Assistant: {introductory_message}\n", fg="green")
messages = [
messages: list[LLMMessage] = [
{"role": "system", "content": system_message},
{"role": "assistant", "content": introductory_message},
]
@@ -113,7 +114,7 @@ def run_chat():
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
def show_loading(event: threading.Event):
def show_loading(event: threading.Event) -> None:
"""Display animated loading dots while processing."""
while not event.is_set():
_printer.print(".", end="")
@@ -162,23 +163,23 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
)
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
"""Creates a wrapper function for running the crew tool with messages."""
def run_crew_tool_with_messages(**kwargs):
def run_crew_tool_with_messages(**kwargs: Any) -> str:
return run_crew_tool(crew, messages, **kwargs)
return run_crew_tool_with_messages
def flush_input():
def flush_input() -> None:
"""Flush any pending input from the user."""
if platform.system() == "Windows":
# Windows platform
import msvcrt
while msvcrt.kbhit():
msvcrt.getch()
while msvcrt.kbhit(): # type: ignore[attr-defined]
msvcrt.getch() # type: ignore[attr-defined]
else:
# Unix-like platforms (Linux, macOS)
import termios
@@ -186,7 +187,12 @@ def flush_input():
termios.tcflush(sys.stdin, termios.TCIFLUSH)
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
def chat_loop(
chat_llm: LLM | BaseLLM,
messages: list[LLMMessage],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
) -> None:
"""Main chat loop for interacting with the user."""
while True:
try:
@@ -225,7 +231,7 @@ def get_user_input() -> str:
def handle_user_input(
user_input: str,
chat_llm: LLM,
chat_llm: LLM | BaseLLM,
messages: list[LLMMessage],
crew_tool_schema: dict[str, Any],
available_functions: dict[str, Any],
@@ -255,7 +261,7 @@ def handle_user_input(
click.secho(f"\nAssistant: {final_response}\n", fg="green")
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
"""
Dynamically build a Littellm 'function' schema for the given crew.
@@ -286,7 +292,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
}
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
"""
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
@@ -372,7 +378,9 @@ def load_crew_and_name() -> tuple[Crew, str]:
return crew_instance, crew_class_name
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
def generate_crew_chat_inputs(
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
) -> ChatInputs:
"""
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
@@ -410,23 +418,12 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
Returns:
Set[str]: A set of placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)}")
required_inputs: set[str] = set()
# Scan tasks
for task in crew.tasks:
text = f"{task.description or ''} {task.expected_output or ''}"
required_inputs.update(placeholder_pattern.findall(text))
# Scan agents
for agent in crew.agents:
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
required_inputs.update(placeholder_pattern.findall(text))
return required_inputs
return crew.fetch_inputs()
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
def generate_input_description_with_ai(
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
) -> str:
"""
Generates an input description using AI based on the context of the crew.
@@ -484,10 +481,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
return response.strip()
return str(response).strip()
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
"""
Generates a brief description of the crew using AI.
@@ -534,4 +531,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
f"{context}"
)
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
return response.strip()
return str(response).strip()

View File

@@ -3,103 +3,56 @@ import json
import os
from pathlib import Path
import sys
from typing import BinaryIO, cast
import tempfile
from typing import Final, Literal, cast
from cryptography.fernet import Fernet
if sys.platform == "win32":
import msvcrt
else:
import fcntl
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
class TokenManager:
def __init__(self, file_path: str = "tokens.enc") -> None:
"""
Initialize the TokenManager class.
"""Manages encrypted token storage."""
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
def __init__(self, file_path: str = "tokens.enc") -> None:
"""Initialize the TokenManager.
Args:
file_path: The file path to store encrypted tokens.
"""
self.file_path = file_path
self.key = self._get_or_create_key()
self.fernet = Fernet(self.key)
@staticmethod
def _acquire_lock(file_handle: BinaryIO) -> None:
"""
Acquire an exclusive lock on a file handle.
Args:
file_handle: Open file handle to lock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
@staticmethod
def _release_lock(file_handle: BinaryIO) -> None:
"""
Release the lock on a file handle.
Args:
file_handle: Open file handle to unlock.
"""
if sys.platform == "win32":
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
else:
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
def _get_or_create_key(self) -> bytes:
"""
Get or create the encryption key with file locking to prevent race conditions.
"""Get or create the encryption key.
Returns:
The encryption key.
The encryption key as bytes.
"""
key_filename = "secret.key"
storage_path = self.get_secure_storage_path()
key_filename: str = "secret.key"
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
lock_file_path = storage_path / f"{key_filename}.lock"
try:
lock_file_path.touch()
with open(lock_file_path, "r+b") as lock_file:
self._acquire_lock(lock_file)
try:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
return new_key
finally:
try:
self._release_lock(lock_file)
except OSError:
pass
except OSError:
key = self.read_secure_file(key_filename)
if key is not None and len(key) == 44:
return key
new_key = Fernet.generate_key()
self.save_secure_file(key_filename, new_key)
new_key = Fernet.generate_key()
if self._atomic_create_secure_file(key_filename, new_key):
return new_key
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""
Save the access token and its expiration time.
key = self._read_secure_file(key_filename)
if key is not None and len(key) == _FERNET_KEY_LENGTH:
return key
:param access_token: The access token to save.
:param expires_at: The UNIX timestamp of the expiration time.
raise RuntimeError("Failed to create or read encryption key")
def save_tokens(self, access_token: str, expires_at: int) -> None:
"""Save the access token and its expiration time.
Args:
access_token: The access token to save.
expires_at: The UNIX timestamp of the expiration time.
"""
expiration_time = datetime.fromtimestamp(expires_at)
data = {
@@ -107,15 +60,15 @@ class TokenManager:
"expiration": expiration_time.isoformat(),
}
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
self.save_secure_file(self.file_path, encrypted_data)
self._atomic_write_secure_file(self.file_path, encrypted_data)
def get_token(self) -> str | None:
"""
Get the access token if it is valid and not expired.
"""Get the access token if it is valid and not expired.
:return: The access token if valid and not expired, otherwise None.
Returns:
The access token if valid and not expired, otherwise None.
"""
encrypted_data = self.read_secure_file(self.file_path)
encrypted_data = self._read_secure_file(self.file_path)
if encrypted_data is None:
return None
@@ -126,20 +79,18 @@ class TokenManager:
if expiration <= datetime.now():
return None
return cast(str | None, data["access_token"])
return cast(str | None, data.get("access_token"))
def clear_tokens(self) -> None:
"""
Clear the tokens.
"""
self.delete_secure_file(self.file_path)
"""Clear the stored tokens."""
self._delete_secure_file(self.file_path)
@staticmethod
def get_secure_storage_path() -> Path:
"""
Get the secure storage path based on the operating system.
def _get_secure_storage_path() -> Path:
"""Get the secure storage path based on the operating system.
:return: The secure storage path.
Returns:
The secure storage path.
"""
if sys.platform == "win32":
base_path = os.environ.get("LOCALAPPDATA")
@@ -155,44 +106,81 @@ class TokenManager:
return storage_path
def save_secure_file(self, filename: str, content: bytes) -> None:
"""
Save the content to a secure file.
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
"""Create a file only if it doesn't exist.
:param filename: The name of the file.
:param content: The content to save.
Args:
filename: The name of the file.
content: The content to write.
Returns:
True if file was created, False if it already exists.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
with open(file_path, "wb") as f:
f.write(content)
try:
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
try:
os.write(fd, content)
finally:
os.close(fd)
return True
except FileExistsError:
return False
os.chmod(file_path, 0o600)
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
"""Write content to a secure file.
def read_secure_file(self, filename: str) -> bytes | None:
Args:
filename: The name of the file.
content: The content to write.
"""
Read the content of a secure file.
:param filename: The name of the file.
:return: The content of the file if it exists, otherwise None.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
if not file_path.exists():
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
fd_closed = False
try:
os.write(fd, content)
os.close(fd)
fd_closed = True
os.chmod(temp_path, 0o600)
os.replace(temp_path, file_path)
except Exception:
if not fd_closed:
os.close(fd)
if os.path.exists(temp_path):
os.unlink(temp_path)
raise
def _read_secure_file(self, filename: str) -> bytes | None:
"""Read the content of a secure file.
Args:
filename: The name of the file.
Returns:
The content of the file if it exists, otherwise None.
"""
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
try:
with open(file_path, "rb") as f:
return f.read()
except FileNotFoundError:
return None
with open(file_path, "rb") as f:
return f.read()
def _delete_secure_file(self, filename: str) -> None:
"""Delete a secure file.
def delete_secure_file(self, filename: str) -> None:
Args:
filename: The name of the file.
"""
Delete the secure file.
:param filename: The name of the file.
"""
storage_path = self.get_secure_storage_path()
storage_path = self._get_secure_storage_path()
file_path = storage_path / filename
if file_path.exists():
file_path.unlink(missing_ok=True)
try:
file_path.unlink()
except FileNotFoundError:
pass

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.1"
"crewai[tools]==1.7.0"
]
[project.scripts]

View File

@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
authors = [{ name = "Your Name", email = "you@example.com" }]
requires-python = ">=3.10,<3.14"
dependencies = [
"crewai[tools]==1.6.1"
"crewai[tools]==1.7.0"
]
[project.scripts]

View File

@@ -35,6 +35,14 @@ from crewai.agent import Agent
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.agents.cache.cache_handler import CacheHandler
from crewai.crews.crew_output import CrewOutput
from crewai.crews.utils import (
StreamingContext,
check_conditional_skip,
enable_agent_streaming,
prepare_kickoff,
prepare_task_execution,
run_for_each_async,
)
from crewai.events.event_bus import crewai_event_bus
from crewai.events.event_listener import EventListener
from crewai.events.listeners.tracing.trace_listener import (
@@ -47,7 +55,6 @@ from crewai.events.listeners.tracing.utils import (
from crewai.events.types.crew_events import (
CrewKickoffCompletedEvent,
CrewKickoffFailedEvent,
CrewKickoffStartedEvent,
CrewTestCompletedEvent,
CrewTestFailedEvent,
CrewTestStartedEvent,
@@ -74,7 +81,7 @@ from crewai.tasks.conditional_task import ConditionalTask
from crewai.tasks.task_output import TaskOutput
from crewai.tools.agent_tools.agent_tools import AgentTools
from crewai.tools.base_tool import BaseTool
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
from crewai.types.streaming import CrewStreamingOutput
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
from crewai.utilities.crew.models import CrewContext
@@ -92,10 +99,8 @@ from crewai.utilities.planning_handler import CrewPlanner
from crewai.utilities.printer import PrinterColor
from crewai.utilities.rpm_controller import RPMController
from crewai.utilities.streaming import (
TaskInfo,
create_async_chunk_generator,
create_chunk_generator,
create_streaming_state,
signal_end,
signal_error,
)
@@ -268,7 +273,7 @@ class Crew(FlowTrackable, BaseModel):
description="list of file paths for task execution JSON files.",
)
execution_logs: list[dict[str, Any]] = Field(
default=[],
default_factory=list,
description="list of execution logs for tasks",
)
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
@@ -348,12 +353,12 @@ class Crew(FlowTrackable, BaseModel):
return self
def _initialize_default_memories(self) -> None:
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
self._long_term_memory = self._long_term_memory or LongTermMemory()
self._short_term_memory = self._short_term_memory or ShortTermMemory(
crew=self,
embedder_config=self.embedder,
)
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
self._entity_memory = self.entity_memory or EntityMemory(
crew=self, embedder_config=self.embedder
)
@@ -404,8 +409,7 @@ class Crew(FlowTrackable, BaseModel):
raise PydanticCustomError(
"missing_manager_llm_or_manager_agent",
(
"Attribute `manager_llm` or `manager_agent` is required "
"when using hierarchical process."
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process."
),
{},
)
@@ -511,10 +515,9 @@ class Crew(FlowTrackable, BaseModel):
raise PydanticCustomError(
"invalid_async_conditional_task",
(
f"Conditional Task: {task.description}, "
f"cannot be executed asynchronously."
"Conditional Task: {description}, cannot be executed asynchronously."
),
{},
{"description": task.description},
)
return self
@@ -675,21 +678,8 @@ class Crew(FlowTrackable, BaseModel):
inputs: dict[str, Any] | None = None,
) -> CrewOutput | CrewStreamingOutput:
if self.stream:
for agent in self.agents:
if agent.llm is not None:
agent.llm.stream = True
result_holder: list[CrewOutput] = []
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(current_task_info, result_holder)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
enable_agent_streaming(self.agents)
ctx = StreamingContext()
def run_crew() -> None:
"""Execute the crew and capture the result."""
@@ -697,59 +687,28 @@ class Crew(FlowTrackable, BaseModel):
self.stream = False
crew_result = self.kickoff(inputs=inputs)
if isinstance(crew_result, CrewOutput):
result_holder.append(crew_result)
ctx.result_holder.append(crew_result)
except Exception as exc:
signal_error(state, exc)
signal_error(ctx.state, exc)
finally:
self.stream = True
signal_end(state)
signal_end(ctx.state)
streaming_output = CrewStreamingOutput(
sync_iterator=create_chunk_generator(state, run_crew, output_holder)
sync_iterator=create_chunk_generator(
ctx.state, run_crew, ctx.output_holder
)
)
output_holder.append(streaming_output)
ctx.output_holder.append(streaming_output)
return streaming_output
ctx = baggage.set_baggage(
baggage_ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
)
token = attach(ctx)
token = attach(baggage_ctx)
try:
for before_callback in self.before_kickoff_callbacks:
if inputs is None:
inputs = {}
inputs = before_callback(inputs)
crewai_event_bus.emit(
self,
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
)
# Starts the crew to work on its assigned tasks.
self._task_output_handler.reset()
self._logging_color = "bold_purple"
if inputs is not None:
self._inputs = inputs
self._interpolate_inputs(inputs)
self._set_tasks_callbacks()
self._set_allow_crewai_trigger_context_for_first_task()
for agent in self.agents:
agent.crew = self
agent.set_knowledge(crew_embedder=self.embedder)
# TODO: Create an AgentFunctionCalling protocol for future refactoring
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
agent.create_agent_executor()
if self.planning:
self._handle_crew_planning()
inputs = prepare_kickoff(self, inputs)
if self.process == Process.sequential:
result = self._run_sequential_process()
@@ -814,42 +773,27 @@ class Crew(FlowTrackable, BaseModel):
inputs = inputs or {}
if self.stream:
for agent in self.agents:
if agent.llm is not None:
agent.llm.stream = True
result_holder: list[CrewOutput] = []
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
enable_agent_streaming(self.agents)
ctx = StreamingContext(use_async=True)
async def run_crew() -> None:
try:
self.stream = False
result = await asyncio.to_thread(self.kickoff, inputs)
if isinstance(result, CrewOutput):
result_holder.append(result)
ctx.result_holder.append(result)
except Exception as e:
signal_error(state, e, is_async=True)
signal_error(ctx.state, e, is_async=True)
finally:
self.stream = True
signal_end(state, is_async=True)
signal_end(ctx.state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_crew, output_holder
ctx.state, run_crew, ctx.output_holder
)
)
output_holder.append(streaming_output)
ctx.output_holder.append(streaming_output)
return streaming_output
@@ -864,89 +808,207 @@ class Crew(FlowTrackable, BaseModel):
from all crews as they arrive. After iteration, access results via .results
(list of CrewOutput).
"""
crew_copies = [self.copy() for _ in inputs]
async def kickoff_fn(
crew: Crew, input_data: dict[str, Any]
) -> CrewOutput | CrewStreamingOutput:
return await crew.kickoff_async(inputs=input_data)
return await run_for_each_async(self, inputs, kickoff_fn)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> CrewOutput | CrewStreamingOutput:
"""Native async kickoff method using async task execution throughout.
Unlike kickoff_async which wraps sync kickoff in a thread, this method
uses native async/await for all operations including task execution,
memory operations, and knowledge queries.
"""
if self.stream:
result_holder: list[list[CrewOutput]] = [[]]
current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
enable_agent_streaming(self.agents)
ctx = StreamingContext(use_async=True)
state = create_streaming_state(
current_task_info, result_holder, use_async=True
)
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_all_crews() -> None:
"""Run all crew copies and aggregate their streaming outputs."""
async def run_crew() -> None:
try:
streaming_outputs: list[CrewStreamingOutput] = []
for i, crew in enumerate(crew_copies):
streaming = await crew.kickoff_async(inputs=inputs[i])
if isinstance(streaming, CrewStreamingOutput):
streaming_outputs.append(streaming)
async def consume_stream(
stream_output: CrewStreamingOutput,
) -> CrewOutput:
"""Consume stream chunks and forward to parent queue.
Args:
stream_output: The streaming output to consume.
Returns:
The final CrewOutput result.
"""
async for chunk in stream_output:
if state.async_queue is not None and state.loop is not None:
state.loop.call_soon_threadsafe(
state.async_queue.put_nowait, chunk
)
return stream_output.result
crew_results = await asyncio.gather(
*[consume_stream(s) for s in streaming_outputs]
)
result_holder[0] = list(crew_results)
except Exception as e:
signal_error(state, e, is_async=True)
self.stream = False
inner_result = await self.akickoff(inputs)
if isinstance(inner_result, CrewOutput):
ctx.result_holder.append(inner_result)
except Exception as exc:
signal_error(ctx.state, exc, is_async=True)
finally:
signal_end(state, is_async=True)
self.stream = True
signal_end(ctx.state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
state, run_all_crews, output_holder
ctx.state, run_crew, ctx.output_holder
)
)
def set_results_wrapper(result: Any) -> None:
"""Wrap _set_results to match _set_result signature."""
streaming_output._set_results(result)
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
output_holder.append(streaming_output)
ctx.output_holder.append(streaming_output)
return streaming_output
tasks = [
asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
]
baggage_ctx = baggage.set_baggage(
"crew_context", CrewContext(id=str(self.id), key=self.key)
)
token = attach(baggage_ctx)
results = await asyncio.gather(*tasks)
try:
inputs = prepare_kickoff(self, inputs)
total_usage_metrics = UsageMetrics()
for crew_copy in crew_copies:
if crew_copy.usage_metrics:
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
self.usage_metrics = total_usage_metrics
if self.process == Process.sequential:
result = await self._arun_sequential_process()
elif self.process == Process.hierarchical:
result = await self._arun_hierarchical_process()
else:
raise NotImplementedError(
f"The process '{self.process}' is not implemented yet."
)
self._task_output_handler.reset()
return list(results)
for after_callback in self.after_kickoff_callbacks:
result = after_callback(result)
self.usage_metrics = self.calculate_usage_metrics()
return result
except Exception as e:
crewai_event_bus.emit(
self,
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
)
raise
finally:
detach(token)
async def akickoff_for_each(
self, inputs: list[dict[str, Any]]
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
"""Native async execution of the Crew's workflow for each input.
Uses native async throughout rather than thread-based async.
If stream=True, returns a single CrewStreamingOutput that yields chunks
from all crews as they arrive.
"""
async def kickoff_fn(
crew: Crew, input_data: dict[str, Any]
) -> CrewOutput | CrewStreamingOutput:
return await crew.akickoff(inputs=input_data)
return await run_for_each_async(self, inputs, kickoff_fn)
async def _arun_sequential_process(self) -> CrewOutput:
"""Executes tasks sequentially using native async and returns the final output."""
return await self._aexecute_tasks(self.tasks)
async def _arun_hierarchical_process(self) -> CrewOutput:
"""Creates and assigns a manager agent to complete the tasks using native async."""
self._create_manager_agent()
return await self._aexecute_tasks(self.tasks)
async def _aexecute_tasks(
self,
tasks: list[Task],
start_index: int | None = 0,
was_replayed: bool = False,
) -> CrewOutput:
"""Executes tasks using native async and returns the final output.
Args:
tasks: List of tasks to execute
start_index: Index to start execution from (for replay)
was_replayed: Whether this is a replayed execution
Returns:
CrewOutput: Final output of the crew
"""
task_outputs: list[TaskOutput] = []
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
exec_data, task_outputs, last_sync_output = prepare_task_execution(
self, task, task_index, start_index, task_outputs, last_sync_output
)
if exec_data.should_skip:
continue
if isinstance(task, ConditionalTask):
skipped_task_output = await self._ahandle_conditional_task(
task, task_outputs, pending_tasks, task_index, was_replayed
)
if skipped_task_output:
task_outputs.append(skipped_task_output)
continue
if task.async_execution:
context = self._get_context(
task, [last_sync_output] if last_sync_output else []
)
async_task = asyncio.create_task(
task.aexecute_sync(
agent=exec_data.agent,
context=context,
tools=exec_data.tools,
)
)
pending_tasks.append((task, async_task, task_index))
else:
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(
pending_tasks, was_replayed
)
pending_tasks.clear()
context = self._get_context(task, task_outputs)
task_output = await task.aexecute_sync(
agent=exec_data.agent,
context=context,
tools=exec_data.tools,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
self._store_execution_log(task, task_output, task_index, was_replayed)
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
return self._create_crew_output(task_outputs)
async def _ahandle_conditional_task(
self,
task: ConditionalTask,
task_outputs: list[TaskOutput],
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
task_index: int,
was_replayed: bool,
) -> TaskOutput | None:
"""Handle conditional task evaluation using native async."""
if pending_tasks:
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
pending_tasks.clear()
return check_conditional_skip(
self, task, task_outputs, task_index, was_replayed
)
async def _aprocess_async_tasks(
self,
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
was_replayed: bool = False,
) -> list[TaskOutput]:
"""Process pending async tasks and return their outputs."""
task_outputs: list[TaskOutput] = []
for future_task, async_task, task_index in pending_tasks:
task_output = await async_task
task_outputs.append(task_output)
self._process_task_result(future_task, task_output)
self._store_execution_log(
future_task, task_output, task_index, was_replayed
)
return task_outputs
def _handle_crew_planning(self) -> None:
"""Handles the Crew planning."""
@@ -1064,33 +1126,11 @@ class Crew(FlowTrackable, BaseModel):
last_sync_output: TaskOutput | None = None
for task_index, task in enumerate(tasks):
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
continue
agent_to_use = self._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
# Determine which tools to use - task tools take precedence over agent tools
tools_for_task = task.tools or agent_to_use.tools or []
# Prepare tools and ensure they're compatible with task execution
tools_for_task = self._prepare_tools(
agent_to_use,
task,
tools_for_task,
exec_data, task_outputs, last_sync_output = prepare_task_execution(
self, task, task_index, start_index, task_outputs, last_sync_output
)
self._log_task_start(task, agent_to_use.role)
if exec_data.should_skip:
continue
if isinstance(task, ConditionalTask):
skipped_task_output = self._handle_conditional_task(
@@ -1105,9 +1145,9 @@ class Crew(FlowTrackable, BaseModel):
task, [last_sync_output] if last_sync_output else []
)
future = task.execute_async(
agent=agent_to_use,
agent=exec_data.agent,
context=context,
tools=tools_for_task,
tools=exec_data.tools,
)
futures.append((task, future, task_index))
else:
@@ -1117,9 +1157,9 @@ class Crew(FlowTrackable, BaseModel):
context = self._get_context(task, task_outputs)
task_output = task.execute_sync(
agent=agent_to_use,
agent=exec_data.agent,
context=context,
tools=tools_for_task,
tools=exec_data.tools,
)
task_outputs.append(task_output)
self._process_task_result(task, task_output)
@@ -1142,19 +1182,9 @@ class Crew(FlowTrackable, BaseModel):
task_outputs = self._process_async_tasks(futures, was_replayed)
futures.clear()
previous_output = task_outputs[-1] if task_outputs else None
if previous_output is not None and not task.should_execute(previous_output):
self._logger.log(
"debug",
f"Skipping conditional task: {task.description}",
color="yellow",
)
skipped_task_output = task.get_skipped_task_output()
if not was_replayed:
self._store_execution_log(task, skipped_task_output, task_index)
return skipped_task_output
return None
return check_conditional_skip(
self, task, task_outputs, task_index, was_replayed
)
def _prepare_tools(
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
@@ -1318,7 +1348,8 @@ class Crew(FlowTrackable, BaseModel):
)
return tools
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
@staticmethod
def _get_context(task: Task, task_outputs: list[TaskOutput]) -> str:
if not task.context:
return ""
@@ -1387,7 +1418,8 @@ class Crew(FlowTrackable, BaseModel):
)
return task_outputs
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
@staticmethod
def _find_task_index(task_id: str, stored_outputs: list[Any]) -> int | None:
return next(
(
index
@@ -1447,6 +1479,16 @@ class Crew(FlowTrackable, BaseModel):
)
return None
async def aquery_knowledge(
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
) -> list[SearchResult] | None:
"""Query the crew's knowledge base for relevant information asynchronously."""
if self.knowledge:
return await self.knowledge.aquery(
query, results_limit=results_limit, score_threshold=score_threshold
)
return None
def fetch_inputs(self) -> set[str]:
"""
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
@@ -1455,7 +1497,7 @@ class Crew(FlowTrackable, BaseModel):
Returns a set of all discovered placeholder names.
"""
placeholder_pattern = re.compile(r"\{(.+?)\}")
placeholder_pattern = re.compile(r"\{(.+?)}")
required_inputs: set[str] = set()
# Scan tasks for inputs
@@ -1703,6 +1745,32 @@ class Crew(FlowTrackable, BaseModel):
self._logger.log("error", error_msg)
raise RuntimeError(error_msg) from e
def _reset_memory_system(
self, system: Any, name: str, reset_fn: Callable[[Any], Any]
) -> None:
"""Reset a single memory system.
Args:
system: The memory system instance to reset.
name: Display name of the memory system for logging.
reset_fn: Function to call to reset the system.
Raises:
RuntimeError: If the reset operation fails.
"""
try:
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
def _reset_all_memories(self) -> None:
"""Reset all available memory systems."""
memory_systems = self._get_memory_systems()
@@ -1710,21 +1778,10 @@ class Crew(FlowTrackable, BaseModel):
for config in memory_systems.values():
if (system := config.get("system")) is not None:
name = config.get("name")
try:
reset_fn: Callable[[Any], Any] = cast(
Callable[[Any], Any], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
reset_fn: Callable[[Any], Any] = cast(
Callable[[Any], Any], config.get("reset")
)
self._reset_memory_system(system, name, reset_fn)
def _reset_specific_memory(self, memory_type: str) -> None:
"""Reset a specific memory system.
@@ -1743,21 +1800,8 @@ class Crew(FlowTrackable, BaseModel):
if system is None:
raise RuntimeError(f"{name} memory system is not initialized")
try:
reset_fn: Callable[[Any], Any] = cast(
Callable[[Any], Any], config.get("reset")
)
reset_fn(system)
self._logger.log(
"info",
f"[Crew ({self.name if self.name else self.id})] "
f"{name} memory has been reset",
)
except Exception as e:
raise RuntimeError(
f"[Crew ({self.name if self.name else self.id})] "
f"Failed to reset {name} memory: {e!s}"
) from e
reset_fn: Callable[[Any], Any] = cast(Callable[[Any], Any], config.get("reset"))
self._reset_memory_system(system, name, reset_fn)
def _get_memory_systems(self) -> dict[str, Any]:
"""Get all available memory systems with their configuration.
@@ -1845,7 +1889,8 @@ class Crew(FlowTrackable, BaseModel):
):
self.tasks[0].allow_crewai_trigger_context = True
def _show_tracing_disabled_message(self) -> None:
@staticmethod
def _show_tracing_disabled_message() -> None:
"""Show a message when tracing is disabled."""
from crewai.events.listeners.tracing.utils import has_user_declined_tracing

View File

@@ -0,0 +1,363 @@
"""Utility functions for crew operations."""
from __future__ import annotations
import asyncio
from collections.abc import Callable, Coroutine, Iterable
from typing import TYPE_CHECKING, Any
from crewai.agents.agent_builder.base_agent import BaseAgent
from crewai.crews.crew_output import CrewOutput
from crewai.rag.embeddings.types import EmbedderConfig
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
from crewai.utilities.streaming import (
StreamingState,
TaskInfo,
create_streaming_state,
)
if TYPE_CHECKING:
from crewai.crew import Crew
def enable_agent_streaming(agents: Iterable[BaseAgent]) -> None:
"""Enable streaming on all agents that have an LLM configured.
Args:
agents: Iterable of agents to enable streaming on.
"""
for agent in agents:
if agent.llm is not None:
agent.llm.stream = True
def setup_agents(
crew: Crew,
agents: Iterable[BaseAgent],
embedder: EmbedderConfig | None,
function_calling_llm: Any,
step_callback: Callable[..., Any] | None,
) -> None:
"""Set up agents for crew execution.
Args:
crew: The crew instance agents belong to.
agents: Iterable of agents to set up.
embedder: Embedder configuration for knowledge.
function_calling_llm: Default function calling LLM for agents.
step_callback: Default step callback for agents.
"""
for agent in agents:
agent.crew = crew
agent.set_knowledge(crew_embedder=embedder)
if not agent.function_calling_llm: # type: ignore[attr-defined]
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
if not agent.step_callback: # type: ignore[attr-defined]
agent.step_callback = step_callback # type: ignore[attr-defined]
agent.create_agent_executor()
class TaskExecutionData:
"""Data container for prepared task execution information."""
def __init__(
self,
agent: BaseAgent | None,
tools: list[Any],
should_skip: bool = False,
) -> None:
"""Initialize task execution data.
Args:
agent: The agent to use for task execution (None if skipped).
tools: Prepared tools for the task.
should_skip: Whether the task should be skipped (replay).
"""
self.agent = agent
self.tools = tools
self.should_skip = should_skip
def prepare_task_execution(
crew: Crew,
task: Any,
task_index: int,
start_index: int | None,
task_outputs: list[Any],
last_sync_output: Any | None,
) -> tuple[TaskExecutionData, list[Any], Any | None]:
"""Prepare a task for execution, handling replay skip logic and agent/tool setup.
Args:
crew: The crew instance.
task: The task to prepare.
task_index: Index of the current task.
start_index: Index to start execution from (for replay).
task_outputs: Current list of task outputs.
last_sync_output: Last synchronous task output.
Returns:
A tuple of (TaskExecutionData or None if skipped, updated task_outputs, updated last_sync_output).
If the task should be skipped, TaskExecutionData will have should_skip=True.
Raises:
ValueError: If no agent is available for the task.
"""
# Handle replay skip
if start_index is not None and task_index < start_index:
if task.output:
if task.async_execution:
task_outputs.append(task.output)
else:
task_outputs = [task.output]
last_sync_output = task.output
return (
TaskExecutionData(agent=None, tools=[], should_skip=True),
task_outputs,
last_sync_output,
)
agent_to_use = crew._get_agent_to_use(task)
if agent_to_use is None:
raise ValueError(
f"No agent available for task: {task.description}. "
f"Ensure that either the task has an assigned agent "
f"or a manager agent is provided."
)
tools_for_task = task.tools or agent_to_use.tools or []
tools_for_task = crew._prepare_tools(
agent_to_use,
task,
tools_for_task,
)
crew._log_task_start(task, agent_to_use.role)
return (
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
task_outputs,
last_sync_output,
)
def check_conditional_skip(
crew: Crew,
task: Any,
task_outputs: list[Any],
task_index: int,
was_replayed: bool,
) -> Any | None:
"""Check if a conditional task should be skipped.
Args:
crew: The crew instance.
task: The conditional task to check.
task_outputs: List of previous task outputs.
task_index: Index of the current task.
was_replayed: Whether this is a replayed execution.
Returns:
The skipped task output if the task should be skipped, None otherwise.
"""
previous_output = task_outputs[-1] if task_outputs else None
if previous_output is not None and not task.should_execute(previous_output):
crew._logger.log(
"debug",
f"Skipping conditional task: {task.description}",
color="yellow",
)
skipped_task_output = task.get_skipped_task_output()
if not was_replayed:
crew._store_execution_log(task, skipped_task_output, task_index)
return skipped_task_output
return None
def prepare_kickoff(crew: Crew, inputs: dict[str, Any] | None) -> dict[str, Any] | None:
"""Prepare crew for kickoff execution.
Handles before callbacks, event emission, task handler reset, input
interpolation, task callbacks, agent setup, and planning.
Args:
crew: The crew instance to prepare.
inputs: Optional input dictionary to pass to the crew.
Returns:
The potentially modified inputs dictionary after before callbacks.
"""
from crewai.events.event_bus import crewai_event_bus
from crewai.events.types.crew_events import CrewKickoffStartedEvent
for before_callback in crew.before_kickoff_callbacks:
if inputs is None:
inputs = {}
inputs = before_callback(inputs)
future = crewai_event_bus.emit(
crew,
CrewKickoffStartedEvent(crew_name=crew.name, inputs=inputs),
)
if future is not None:
try:
future.result()
except Exception: # noqa: S110
pass
crew._task_output_handler.reset()
crew._logging_color = "bold_purple"
if inputs is not None:
crew._inputs = inputs
crew._interpolate_inputs(inputs)
crew._set_tasks_callbacks()
crew._set_allow_crewai_trigger_context_for_first_task()
setup_agents(
crew,
crew.agents,
crew.embedder,
crew.function_calling_llm,
crew.step_callback,
)
if crew.planning:
crew._handle_crew_planning()
return inputs
class StreamingContext:
"""Container for streaming state and holders used during crew execution."""
def __init__(self, use_async: bool = False) -> None:
"""Initialize streaming context.
Args:
use_async: Whether to use async streaming mode.
"""
self.result_holder: list[CrewOutput] = []
self.current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
self.state: StreamingState = create_streaming_state(
self.current_task_info, self.result_holder, use_async=use_async
)
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
class ForEachStreamingContext:
"""Container for streaming state used in for_each crew execution methods."""
def __init__(self) -> None:
"""Initialize for_each streaming context."""
self.result_holder: list[list[CrewOutput]] = [[]]
self.current_task_info: TaskInfo = {
"index": 0,
"name": "",
"id": "",
"agent_role": "",
"agent_id": "",
}
self.state: StreamingState = create_streaming_state(
self.current_task_info, self.result_holder, use_async=True
)
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
async def run_for_each_async(
crew: Crew,
inputs: list[dict[str, Any]],
kickoff_fn: Callable[
[Crew, dict[str, Any]], Coroutine[Any, Any, CrewOutput | CrewStreamingOutput]
],
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
"""Execute crew workflow for each input asynchronously.
Args:
crew: The crew instance to execute.
inputs: List of input dictionaries for each execution.
kickoff_fn: Async function to call for each crew copy (kickoff_async or akickoff).
Returns:
If streaming, a single CrewStreamingOutput that yields chunks from all crews.
Otherwise, a list of CrewOutput results.
"""
from crewai.types.usage_metrics import UsageMetrics
from crewai.utilities.streaming import (
create_async_chunk_generator,
signal_end,
signal_error,
)
crew_copies = [crew.copy() for _ in inputs]
if crew.stream:
ctx = ForEachStreamingContext()
async def run_all_crews() -> None:
try:
streaming_outputs: list[CrewStreamingOutput] = []
for i, crew_copy in enumerate(crew_copies):
streaming = await kickoff_fn(crew_copy, inputs[i])
if isinstance(streaming, CrewStreamingOutput):
streaming_outputs.append(streaming)
async def consume_stream(
stream_output: CrewStreamingOutput,
) -> CrewOutput:
async for chunk in stream_output:
if (
ctx.state.async_queue is not None
and ctx.state.loop is not None
):
ctx.state.loop.call_soon_threadsafe(
ctx.state.async_queue.put_nowait, chunk
)
return stream_output.result
crew_results = await asyncio.gather(
*[consume_stream(s) for s in streaming_outputs]
)
ctx.result_holder[0] = list(crew_results)
except Exception as e:
signal_error(ctx.state, e, is_async=True)
finally:
signal_end(ctx.state, is_async=True)
streaming_output = CrewStreamingOutput(
async_iterator=create_async_chunk_generator(
ctx.state, run_all_crews, ctx.output_holder
)
)
def set_results_wrapper(result: Any) -> None:
streaming_output._set_results(result)
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
ctx.output_holder.append(streaming_output)
return streaming_output
async_tasks: list[asyncio.Task[CrewOutput | CrewStreamingOutput]] = [
asyncio.create_task(kickoff_fn(crew_copy, input_data))
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
]
results = await asyncio.gather(*async_tasks)
total_usage_metrics = UsageMetrics()
for crew_copy in crew_copies:
if crew_copy.usage_metrics:
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
crew.usage_metrics = total_usage_metrics
crew._task_output_handler.reset()
return list(results)

View File

@@ -140,7 +140,9 @@ class EventListener(BaseEventListener):
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
with self._crew_tree_lock:
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
self._telemetry.crew_execution_span(source, event.inputs)
source._execution_span = self._telemetry.crew_execution_span(
source, event.inputs
)
self._crew_tree_lock.notify_all()
@crewai_event_bus.on(CrewKickoffCompletedEvent)

View File

@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
finally:
detach(flow_token)
async def akickoff(
self, inputs: dict[str, Any] | None = None
) -> Any | FlowStreamingOutput:
"""Native async method to start the flow execution. Alias for kickoff_async.
Args:
inputs: Optional dictionary containing input values and/or a state ID for restoration.
Returns:
The final output from the flow, which is the result of the last executed method.
"""
return await self.kickoff_async(inputs)
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
"""Executes a flow's start method and its triggered listeners.

View File

@@ -1,6 +1,6 @@
from __future__ import annotations
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast
from crewai.events.event_listener import event_listener
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
@@ -9,17 +9,22 @@ from crewai.utilities.printer import Printer
if TYPE_CHECKING:
from crewai.agents.crew_agent_executor import CrewAgentExecutor
from crewai.lite_agent import LiteAgent
from crewai.llms.base_llm import BaseLLM
from crewai.utilities.types import LLMMessage
class LLMCallHookContext:
"""Context object passed to LLM call hooks with full executor access.
"""Context object passed to LLM call hooks.
Provides hooks with complete access to the executor state, allowing
Provides hooks with complete access to the execution state, allowing
modification of messages, responses, and executor attributes.
Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
Attributes:
executor: Full reference to the CrewAgentExecutor instance
messages: Direct reference to executor.messages (mutable list).
executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
messages: Direct reference to messages (mutable list).
Can be modified in both before_llm_call and after_llm_call hooks.
Modifications in after_llm_call hooks persist to the next iteration,
allowing hooks to modify conversation history for subsequent LLM calls.
@@ -27,33 +32,75 @@ class LLMCallHookContext:
Do NOT replace the list (e.g., context.messages = []), as this will break
the executor. Use context.messages.append() or context.messages.extend()
instead of assignment.
agent: Reference to the agent executing the task
task: Reference to the task being executed
crew: Reference to the crew instance
agent: Reference to the agent executing the task (None for direct LLM calls)
task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
llm: Reference to the LLM instance
iterations: Current iteration count
iterations: Current iteration count (0 for direct LLM calls)
response: LLM response string (only set for after_llm_call hooks).
Can be modified by returning a new string from after_llm_call hook.
"""
executor: CrewAgentExecutor | LiteAgent | None
messages: list[LLMMessage]
agent: Any
task: Any
crew: Any
llm: BaseLLM | None | str | Any
iterations: int
response: str | None
def __init__(
self,
executor: CrewAgentExecutor,
executor: CrewAgentExecutor | LiteAgent | None = None,
response: str | None = None,
messages: list[LLMMessage] | None = None,
llm: BaseLLM | str | Any | None = None, # TODO: look into
agent: Any | None = None,
task: Any | None = None,
crew: Any | None = None,
) -> None:
"""Initialize hook context with executor reference.
"""Initialize hook context with executor reference or direct parameters.
Args:
executor: The CrewAgentExecutor instance
executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
response: Optional response string (for after_llm_call hooks)
messages: Optional messages list (for direct LLM calls when executor is None)
llm: Optional LLM instance (for direct LLM calls when executor is None)
agent: Optional agent reference (for direct LLM calls when executor is None)
task: Optional task reference (for direct LLM calls when executor is None)
crew: Optional crew reference (for direct LLM calls when executor is None)
"""
self.executor = executor
self.messages = executor.messages
self.agent = executor.agent
self.task = executor.task
self.crew = executor.crew
self.llm = executor.llm
self.iterations = executor.iterations
if executor is not None:
# Existing path: extract from executor
self.executor = executor
self.messages = executor.messages
self.llm = executor.llm
self.iterations = executor.iterations
# Handle CrewAgentExecutor vs LiteAgent differences
if hasattr(executor, "agent"):
self.agent = executor.agent
self.task = cast("CrewAgentExecutor", executor).task
self.crew = cast("CrewAgentExecutor", executor).crew
else:
# LiteAgent case - is the agent itself, doesn't have task/crew
self.agent = (
executor.original_agent
if hasattr(executor, "original_agent")
else executor
)
self.task = None
self.crew = None
else:
# New path: direct LLM call with explicit parameters
self.executor = None
self.messages = messages or []
self.llm = llm
self.agent = agent
self.task = task
self.crew = crew
self.iterations = 0
self.response = response
def request_human_input(

View File

@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
sources: list[BaseKnowledgeSource],
embedder: EmbedderConfig | None = None,
storage: KnowledgeStorage | None = None,
**data,
):
**data: object,
) -> None:
super().__init__(**data)
if storage:
self.storage = storage
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
self.storage.reset()
else:
raise ValueError("Storage is not initialized.")
async def aquery(
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
) -> list[SearchResult]:
"""Query across all knowledge sources asynchronously.
Args:
query: List of query strings.
results_limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
The top results matching the query.
Raises:
ValueError: If storage is not initialized.
"""
if self.storage is None:
raise ValueError("Storage is not initialized.")
return await self.storage.asearch(
query,
limit=results_limit,
score_threshold=score_threshold,
)
async def aadd_sources(self) -> None:
"""Add all knowledge sources to storage asynchronously."""
try:
for source in self.sources:
source.storage = self.storage
await source.aadd()
except Exception as e:
raise e
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
if self.storage:
await self.storage.areset()
else:
raise ValueError("Storage is not initialized.")

View File

@@ -1,5 +1,6 @@
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any
from pydantic import Field, field_validator
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
raise ValueError("Either file_path or file_paths must be provided")
return v
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to load content."""
self.safe_file_paths = self._process_file_paths()
self.validate_content()
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
def load_content(self) -> dict[Path, str]:
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
color="red",
)
def _save_documents(self):
def _save_documents(self) -> None:
"""Save the documents to the storage."""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously."""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")
def convert_to_path(self, path: Path | str) -> Path:
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path

View File

@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
]
def _save_documents(self):
"""
Save the documents to the storage.
def _save_documents(self) -> None:
"""Save the documents to the storage.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
self.storage.save(self.chunks)
else:
raise ValueError("No storage found to save documents.")
@abstractmethod
async def aadd(self) -> None:
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
async def _asave_documents(self) -> None:
"""Save the documents to the storage asynchronously.
This method should be called after the chunks and embeddings are generated.
Raises:
ValueError: If no storage is configured.
"""
if self.storage:
await self.storage.asave(self.chunks)
else:
raise ValueError("No storage found to save documents.")

View File

@@ -2,27 +2,24 @@ from __future__ import annotations
from collections.abc import Iterator
from pathlib import Path
from typing import TYPE_CHECKING, Any
from urllib.parse import urlparse
try:
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
InputFormat,
)
from docling.document_converter import ( # type: ignore[import-not-found]
DocumentConverter,
)
from docling.exceptions import ConversionError # type: ignore[import-not-found]
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
HierarchicalChunker,
)
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
DoclingDocument,
)
from docling.datamodel.base_models import InputFormat
from docling.document_converter import DocumentConverter
from docling.exceptions import ConversionError
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
from docling_core.types.doc.document import DoclingDocument
DOCLING_AVAILABLE = True
except ImportError:
DOCLING_AVAILABLE = False
# Provide type stubs for when docling is not available
if TYPE_CHECKING:
from docling.document_converter import DocumentConverter
from docling_core.types.doc.document import DoclingDocument
from pydantic import Field
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
class CrewDoclingSource(BaseKnowledgeSource):
"""Default Source class for converting documents to markdown or json
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
"""Default Source class for converting documents to markdown or json.
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
any additional dependencies and follows the docling package as the source of truth.
"""
def __init__(self, *args, **kwargs):
def __init__(self, *args: Any, **kwargs: Any) -> None:
if not DOCLING_AVAILABLE:
raise ImportError(
"The docling package is required to use CrewDoclingSource. "
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
)
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
self.chunks.extend(list(new_chunks_iterable))
self._save_documents()
async def aadd(self) -> None:
"""Add docling content asynchronously."""
if self.content is None:
return
for doc in self.content:
new_chunks_iterable = self._chunk_doc(doc)
self.chunks.extend(list(new_chunks_iterable))
await self._asave_documents()
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
return [result.document for result in conv_results_iter]

View File

@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add CSV file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,6 @@
from pathlib import Path
from types import ModuleType
from typing import Any
from pydantic import Field, field_validator
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
safe_file_paths: list[Path] = Field(default_factory=list)
@field_validator("file_path", "file_paths", mode="before")
def validate_file_path(cls, v, info): # noqa: N805
@classmethod
def validate_file_path(
cls, v: Path | list[Path] | str | list[str] | None, info: Any
) -> Path | list[Path] | str | list[str] | None:
"""Validate that at least one of file_path or file_paths is provided."""
# Single check if both are None, O(1) instead of nested conditions
if (
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
return [self.convert_to_path(path) for path in path_list]
def validate_content(self):
def validate_content(self) -> None:
"""Validate the paths."""
for path in self.safe_file_paths:
if not path.exists():
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
color="red",
)
def model_post_init(self, _) -> None:
def model_post_init(self, _: Any) -> None:
if self.file_path:
self._logger.log(
"warning",
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
"""Convert a path to a Path object."""
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
def _import_dependencies(self):
def _import_dependencies(self) -> ModuleType:
"""Dynamically import dependencies."""
try:
import pandas as pd # type: ignore[import-untyped,import-not-found]
import pandas as pd # type: ignore[import-untyped]
return pd
return pd # type: ignore[no-any-return]
except ImportError as e:
missing_package = str(e).split()[-1]
raise ImportError(
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add Excel file content asynchronously."""
content_str = ""
for value in self.content.values():
if isinstance(value, dict):
for sheet_value in value.values():
content_str += str(sheet_value) + "\n"
else:
content_str += str(value) + "\n"
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add JSON file content asynchronously."""
content_str = (
str(self.content) if isinstance(self.content, dict) else self.content
)
new_chunks = self._chunk_text(content_str)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,4 +1,5 @@
from pathlib import Path
from types import ModuleType
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
content[path] = text
return content
def _import_pdfplumber(self):
def _import_pdfplumber(self) -> ModuleType:
"""Dynamically import pdfplumber."""
try:
import pdfplumber
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add PDF file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -1,3 +1,5 @@
from typing import Any
from pydantic import Field
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
content: str = Field(...)
collection_name: str | None = Field(default=None)
def model_post_init(self, _):
def model_post_init(self, _: Any) -> None:
"""Post-initialization method to validate content."""
self.validate_content()
def validate_content(self):
def validate_content(self) -> None:
"""Validate string content."""
if not isinstance(self.content, str):
raise ValueError("StringKnowledgeSource only accepts string content")
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add string content asynchronously."""
new_chunks = self._chunk_text(self.content)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
self.chunks.extend(new_chunks)
self._save_documents()
async def aadd(self) -> None:
"""Add text file content asynchronously."""
for text in self.content.values():
new_chunks = self._chunk_text(text)
self.chunks.extend(new_chunks)
await self._asave_documents()
def _chunk_text(self, text: str) -> list[str]:
"""Utility method to split text into chunks."""
return [

View File

@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
) -> list[SearchResult]:
"""Search for documents in the knowledge base."""
@abstractmethod
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously."""
@abstractmethod
def save(self, documents: list[str]) -> None:
"""Save documents to the knowledge base."""
@abstractmethod
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously."""
@abstractmethod
def reset(self) -> None:
"""Reset the knowledge base."""
@abstractmethod
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""

View File

@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
def __init__(
self,
embedder: ProviderSpec
| BaseEmbeddingsProvider
| type[BaseEmbeddingsProvider]
| BaseEmbeddingsProvider[Any]
| type[BaseEmbeddingsProvider[Any]]
| None = None,
collection_name: str | None = None,
) -> None:
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def asearch(
self,
query: list[str],
limit: int = 5,
metadata_filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[SearchResult]:
"""Search for documents in the knowledge base asynchronously.
Args:
query: List of query strings.
limit: Maximum number of results to return.
metadata_filter: Optional metadata filter for the search.
score_threshold: Minimum similarity score for results.
Returns:
List of search results.
"""
try:
if not query:
raise ValueError("Query cannot be empty")
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
query_text = " ".join(query) if len(query) > 1 else query[0]
return await client.asearch(
collection_name=collection_name,
query=query_text,
limit=limit,
metadata_filter=metadata_filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
)
return []
async def asave(self, documents: list[str]) -> None:
"""Save documents to the knowledge base asynchronously.
Args:
documents: List of document strings to save.
"""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.aget_or_create_collection(collection_name=collection_name)
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
await client.aadd_documents(
collection_name=collection_name, documents=rag_documents
)
except Exception as e:
if "dimension mismatch" in str(e).lower():
Logger(verbose=True).log(
"error",
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
"red",
)
raise ValueError(
"Embedding dimension mismatch. Make sure you're using the same embedding model "
"across all operations with this collection."
"Try resetting the collection using `crewai reset-memories -a`"
) from e
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
raise
async def areset(self) -> None:
"""Reset the knowledge base asynchronously."""
try:
client = self._get_client()
collection_name = (
f"knowledge_{self.collection_name}"
if self.collection_name
else "knowledge"
)
await client.adelete_collection(collection_name=collection_name)
except Exception as e:
logging.error(
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
)

View File

@@ -38,6 +38,8 @@ from crewai.events.types.agent_events import (
)
from crewai.events.types.logging_events import AgentLogsExecutionEvent
from crewai.flow.flow_trackable import FlowTrackable
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
from crewai.lite_agent_output import LiteAgentOutput
from crewai.llm import LLM
from crewai.llms.base_llm import BaseLLM
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
_guardrail_retry_count: int = PrivateAttr(default=0)
_callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
_before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
default_factory=get_before_llm_call_hooks
)
_after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
default_factory=get_after_llm_call_hooks
)
@model_validator(mode="after")
def setup_llm(self) -> Self:
@@ -246,6 +254,26 @@ class LiteAgent(FlowTrackable, BaseModel):
"""Return the original role for compatibility with tool interfaces."""
return self.role
@property
def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
"""Get the before_llm_call hooks for this agent."""
return self._before_llm_call_hooks
@property
def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
"""Get the after_llm_call hooks for this agent."""
return self._after_llm_call_hooks
@property
def messages(self) -> list[LLMMessage]:
"""Get the messages list for hook context compatibility."""
return self._messages
@property
def iterations(self) -> int:
"""Get the current iteration count for hook context compatibility."""
return self._iterations
def kickoff(
self,
messages: str | list[LLMMessage],
@@ -504,7 +532,7 @@ class LiteAgent(FlowTrackable, BaseModel):
AgentFinish: The final result of the agent execution.
"""
# Execute the agent loop
formatted_answer = None
formatted_answer: AgentAction | AgentFinish | None = None
while not isinstance(formatted_answer, AgentFinish):
try:
if has_reached_max_iterations(self._iterations, self.max_iterations):
@@ -526,6 +554,7 @@ class LiteAgent(FlowTrackable, BaseModel):
callbacks=self._callbacks,
printer=self._printer,
from_agent=self,
executor_context=self,
)
except Exception as e:

View File

@@ -67,6 +67,7 @@ if TYPE_CHECKING:
from crewai.agent.core import Agent
from crewai.llms.hooks.base import BaseInterceptor
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
from crewai.task import Task
from crewai.tools.base_tool import BaseTool
from crewai.utilities.types import LLMMessage
@@ -585,6 +586,7 @@ class LLM(BaseLLM):
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
stream: bool = False,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize LLM instance.
@@ -1642,6 +1644,10 @@ class LLM(BaseLLM):
if message.get("role") == "system":
msg_role: Literal["assistant"] = "assistant"
message["role"] = msg_role
if not self._invoke_before_llm_call_hooks(messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# --- 5) Set up callbacks if provided
with suppress_warnings():
if callbacks and len(callbacks) > 0:
@@ -1651,7 +1657,16 @@ class LLM(BaseLLM):
params = self._prepare_completion_params(messages, tools)
# --- 7) Make the completion call and handle response
if self.stream:
return self._handle_streaming_response(
result = self._handle_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
else:
result = self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
@@ -1660,14 +1675,12 @@ class LLM(BaseLLM):
response_model=response_model,
)
return self._handle_non_streaming_response(
params=params,
callbacks=callbacks,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
if isinstance(result, str):
result = self._invoke_after_llm_call_hooks(
messages, result, from_agent
)
return result
except LLMContextLengthExceededError:
# Re-raise LLMContextLengthExceededError as it should be handled
# by the CrewAgentExecutor._invoke_loop method, which can then decide

View File

@@ -314,7 +314,7 @@ class BaseLLM(ABC):
call_type: LLMCallType,
from_task: Task | None = None,
from_agent: Agent | None = None,
messages: str | list[dict[str, Any]] | None = None,
messages: str | list[LLMMessage] | None = None,
) -> None:
"""Emit LLM call completed event."""
crewai_event_bus.emit(
@@ -586,3 +586,134 @@ class BaseLLM(ABC):
Dictionary with token usage totals
"""
return UsageMetrics(**self._token_usage)
def _invoke_before_llm_call_hooks(
self,
messages: list[LLMMessage],
from_agent: Agent | None = None,
) -> bool:
"""Invoke before_llm_call hooks for direct LLM calls (no agent context).
This method should be called by native provider implementations before
making the actual LLM call when from_agent is None (direct calls).
Args:
messages: The messages being sent to the LLM
from_agent: The agent making the call (None for direct calls)
Returns:
True if LLM call should proceed, False if blocked by hook
Example:
>>> # In a native provider's call() method:
>>> if from_agent is None and not self._invoke_before_llm_call_hooks(
... messages, from_agent
... ):
... raise ValueError("LLM call blocked by hook")
"""
# Only invoke hooks for direct calls (no agent context)
if from_agent is not None:
return True
from crewai.hooks.llm_hooks import (
LLMCallHookContext,
get_before_llm_call_hooks,
)
from crewai.utilities.printer import Printer
before_hooks = get_before_llm_call_hooks()
if not before_hooks:
return True
hook_context = LLMCallHookContext(
executor=None,
messages=messages,
llm=self,
agent=None,
task=None,
crew=None,
)
printer = Printer()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
printer.print(
content="LLM call blocked by before_llm_call hook",
color="yellow",
)
return False
except Exception as e:
printer.print(
content=f"Error in before_llm_call hook: {e}",
color="yellow",
)
return True
def _invoke_after_llm_call_hooks(
self,
messages: list[LLMMessage],
response: str,
from_agent: Agent | None = None,
) -> str:
"""Invoke after_llm_call hooks for direct LLM calls (no agent context).
This method should be called by native provider implementations after
receiving the LLM response when from_agent is None (direct calls).
Args:
messages: The messages that were sent to the LLM
response: The response from the LLM
from_agent: The agent that made the call (None for direct calls)
Returns:
The potentially modified response string
Example:
>>> # In a native provider's call() method:
>>> if from_agent is None and isinstance(result, str):
... result = self._invoke_after_llm_call_hooks(
... messages, result, from_agent
... )
"""
# Only invoke hooks for direct calls (no agent context)
if from_agent is not None or not isinstance(response, str):
return response
from crewai.hooks.llm_hooks import (
LLMCallHookContext,
get_after_llm_call_hooks,
)
from crewai.utilities.printer import Printer
after_hooks = get_after_llm_call_hooks()
if not after_hooks:
return response
hook_context = LLMCallHookContext(
executor=None,
messages=messages,
llm=self,
agent=None,
task=None,
crew=None,
response=response,
)
printer = Printer()
modified_response = response
try:
for hook in after_hooks:
result = hook(hook_context)
if result is not None and isinstance(result, str):
modified_response = result
hook_context.response = modified_response
except Exception as e:
printer.print(
content=f"Error in after_llm_call hook: {e}",
color="yellow",
)
return modified_response

View File

@@ -3,8 +3,9 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any, cast
from typing import TYPE_CHECKING, Any, Literal, cast
from anthropic.types import ThinkingBlock
from pydantic import BaseModel
from crewai.events.types.llm_events import LLMCallType
@@ -22,8 +23,7 @@ if TYPE_CHECKING:
try:
from anthropic import Anthropic, AsyncAnthropic
from anthropic.types import Message
from anthropic.types.tool_use_block import ToolUseBlock
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
import httpx
except ImportError:
raise ImportError(
@@ -31,6 +31,11 @@ except ImportError:
) from None
class AnthropicThinkingConfig(BaseModel):
type: Literal["enabled", "disabled"]
budget_tokens: int | None = None
class AnthropicCompletion(BaseLLM):
"""Anthropic native completion implementation.
@@ -52,6 +57,7 @@ class AnthropicCompletion(BaseLLM):
stream: bool = False,
client_params: dict[str, Any] | None = None,
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
thinking: AnthropicThinkingConfig | None = None,
**kwargs: Any,
):
"""Initialize Anthropic chat completion client.
@@ -97,6 +103,10 @@ class AnthropicCompletion(BaseLLM):
self.top_p = top_p
self.stream = stream
self.stop_sequences = stop_sequences or []
self.thinking = thinking
self.previous_thinking_blocks: list[ThinkingBlock] = []
# Model-specific settings
self.is_claude_3 = "claude-3" in model.lower()
self.supports_tools = True
@property
@@ -187,6 +197,9 @@ class AnthropicCompletion(BaseLLM):
messages
)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, system_message, tools
@@ -323,6 +336,12 @@ class AnthropicCompletion(BaseLLM):
if tools and self.supports_tools:
params["tools"] = self._convert_tools_for_interference(tools)
if self.thinking:
if isinstance(self.thinking, AnthropicThinkingConfig):
params["thinking"] = self.thinking.model_dump()
else:
params["thinking"] = self.thinking
return params
def _convert_tools_for_interference(
@@ -362,6 +381,34 @@ class AnthropicCompletion(BaseLLM):
return anthropic_tools
def _extract_thinking_block(
self, content_block: Any
) -> ThinkingBlock | dict[str, Any] | None:
"""Extract and format thinking block from content block.
Args:
content_block: Content block from Anthropic response
Returns:
Dictionary with thinking block data including signature, or None if not a thinking block
"""
if content_block.type == "thinking":
thinking_block = {
"type": "thinking",
"thinking": content_block.thinking,
}
if hasattr(content_block, "signature"):
thinking_block["signature"] = content_block.signature
return thinking_block
if content_block.type == "redacted_thinking":
redacted_block = {"type": "redacted_thinking"}
if hasattr(content_block, "thinking"):
redacted_block["thinking"] = content_block.thinking
if hasattr(content_block, "signature"):
redacted_block["signature"] = content_block.signature
return redacted_block
return None
def _format_messages_for_anthropic(
self, messages: str | list[LLMMessage]
) -> tuple[list[LLMMessage], str | None]:
@@ -371,6 +418,7 @@ class AnthropicCompletion(BaseLLM):
- System messages are separate from conversation messages
- Messages must alternate between user and assistant
- First message must be from user
- When thinking is enabled, assistant messages must start with thinking blocks
Args:
messages: Input messages
@@ -395,8 +443,29 @@ class AnthropicCompletion(BaseLLM):
system_message = cast(str, content)
else:
role_str = role if role is not None else "user"
content_str = content if content is not None else ""
formatted_messages.append({"role": role_str, "content": content_str})
if isinstance(content, list):
formatted_messages.append({"role": role_str, "content": content})
elif (
role_str == "assistant"
and self.thinking
and self.previous_thinking_blocks
):
structured_content = cast(
list[dict[str, Any]],
[
*self.previous_thinking_blocks,
{"type": "text", "text": content if content else ""},
],
)
formatted_messages.append(
LLMMessage(role=role_str, content=structured_content)
)
else:
content_str = content if content is not None else ""
formatted_messages.append(
LLMMessage(role=role_str, content=content_str)
)
# Ensure first message is from user (Anthropic requirement)
if not formatted_messages:
@@ -446,7 +515,6 @@ class AnthropicCompletion(BaseLLM):
if tool_uses and tool_uses[0].name == "structured_output":
structured_data = tool_uses[0].input
structured_json = json.dumps(structured_data)
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
@@ -474,15 +542,22 @@ class AnthropicCompletion(BaseLLM):
from_agent,
)
# Extract text content
content = ""
thinking_blocks: list[ThinkingBlock] = []
if response.content:
for content_block in response.content:
if hasattr(content_block, "text"):
content += content_block.text
else:
thinking_block = self._extract_thinking_block(content_block)
if thinking_block:
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
@@ -494,7 +569,9 @@ class AnthropicCompletion(BaseLLM):
if usage.get("total_tokens", 0) > 0:
logging.info(f"Anthropic API usage: {usage}")
return content
return self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
def _handle_streaming_completion(
self,
@@ -535,6 +612,16 @@ class AnthropicCompletion(BaseLLM):
final_message: Message = stream.get_final_message()
thinking_blocks: list[ThinkingBlock] = []
if final_message.content:
for content_block in final_message.content:
thinking_block = self._extract_thinking_block(content_block)
if thinking_block:
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
usage = self._extract_anthropic_token_usage(final_message)
self._track_token_usage_internal(usage)
@@ -588,7 +675,52 @@ class AnthropicCompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
def _execute_tools_and_collect_results(
self,
tool_uses: list[ToolUseBlock],
available_functions: dict[str, Any],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> list[dict[str, Any]]:
"""Execute tools and collect results in Anthropic format.
Args:
tool_uses: List of tool use blocks from Claude's response
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
List of tool result dictionaries in Anthropic format
"""
tool_results = []
for tool_use in tool_uses:
function_name = tool_use.name
function_args = tool_use.input
result = self._handle_tool_execution(
function_name=function_name,
function_args=cast(dict[str, Any], function_args),
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
tool_result = {
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": str(result)
if result is not None
else "Tool execution completed",
}
tool_results.append(tool_result)
return tool_results
def _handle_tool_use_conversation(
self,
@@ -607,37 +739,33 @@ class AnthropicCompletion(BaseLLM):
3. We send tool results back to Claude
4. Claude processes results and generates final response
"""
# Execute all requested tools and collect results
tool_results = []
tool_results = self._execute_tools_and_collect_results(
tool_uses, available_functions, from_task, from_agent
)
for tool_use in tool_uses:
function_name = tool_use.name
function_args = tool_use.input
# Execute the tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
# Create tool result in Anthropic format
tool_result = {
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": str(result)
if result is not None
else "Tool execution completed",
}
tool_results.append(tool_result)
# Prepare follow-up conversation with tool results
follow_up_params = params.copy()
# Add Claude's tool use response to conversation
assistant_message = {"role": "assistant", "content": initial_response.content}
assistant_content: list[
ThinkingBlock | ToolUseBlock | TextBlock | dict[str, Any]
] = []
for block in initial_response.content:
thinking_block = self._extract_thinking_block(block)
if thinking_block:
assistant_content.append(thinking_block)
elif block.type == "tool_use":
assistant_content.append(
{
"type": "tool_use",
"id": block.id,
"name": block.name,
"input": block.input,
}
)
elif hasattr(block, "text"):
assistant_content.append({"type": "text", "text": block.text})
assistant_message = {"role": "assistant", "content": assistant_content}
# Add user message with tool results
user_message = {"role": "user", "content": tool_results}
@@ -656,12 +784,20 @@ class AnthropicCompletion(BaseLLM):
follow_up_usage = self._extract_anthropic_token_usage(final_response)
self._track_token_usage_internal(follow_up_usage)
# Extract final text content
final_content = ""
thinking_blocks: list[ThinkingBlock] = []
if final_response.content:
for content_block in final_response.content:
if hasattr(content_block, "text"):
final_content += content_block.text
else:
thinking_block = self._extract_thinking_block(content_block)
if thinking_block:
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
if thinking_blocks:
self.previous_thinking_blocks = thinking_blocks
final_content = self._apply_stop_words(final_content)
@@ -694,7 +830,7 @@ class AnthropicCompletion(BaseLLM):
logging.error(f"Tool follow-up conversation failed: {e}")
# Fallback: return the first tool result if follow-up fails
if tool_results:
return tool_results[0]["content"]
return cast(str, tool_results[0]["content"])
raise e
async def _ahandle_completion(
@@ -887,28 +1023,9 @@ class AnthropicCompletion(BaseLLM):
3. We send tool results back to Claude
4. Claude processes results and generates final response
"""
tool_results = []
for tool_use in tool_uses:
function_name = tool_use.name
function_args = tool_use.input
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
tool_result = {
"type": "tool_result",
"tool_use_id": tool_use.id,
"content": str(result)
if result is not None
else "Tool execution completed",
}
tool_results.append(tool_result)
tool_results = self._execute_tools_and_collect_results(
tool_uses, available_functions, from_task, from_agent
)
follow_up_params = params.copy()
@@ -963,7 +1080,7 @@ class AnthropicCompletion(BaseLLM):
logging.error(f"Tool follow-up conversation failed: {e}")
if tool_results:
return tool_results[0]["content"]
return cast(str, tool_results[0]["content"])
raise e
def supports_function_calling(self) -> bool:
@@ -999,7 +1116,8 @@ class AnthropicCompletion(BaseLLM):
# Default context window size for Claude models
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
@staticmethod
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
"""Extract token usage from Anthropic response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import json
import logging
import os
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, TypedDict
from pydantic import BaseModel
from typing_extensions import Self
@@ -18,7 +18,6 @@ from crewai.utilities.types import LLMMessage
if TYPE_CHECKING:
from crewai.llms.hooks.base import BaseInterceptor
from crewai.tools.base_tool import BaseTool
try:
@@ -31,6 +30,8 @@ try:
from azure.ai.inference.models import (
ChatCompletions,
ChatCompletionsToolCall,
ChatCompletionsToolDefinition,
FunctionDefinition,
JsonSchemaFormat,
StreamingChatCompletionsUpdate,
)
@@ -50,6 +51,24 @@ except ImportError:
) from None
class AzureCompletionParams(TypedDict, total=False):
"""Type definition for Azure chat completion parameters."""
messages: list[LLMMessage]
stream: bool
model_extras: dict[str, Any]
response_format: JsonSchemaFormat
model: str
temperature: float
top_p: float
frequency_penalty: float
presence_penalty: float
max_tokens: int
stop: list[str]
tools: list[ChatCompletionsToolDefinition]
tool_choice: str
class AzureCompletion(BaseLLM):
"""Azure AI Inference native completion implementation.
@@ -156,7 +175,8 @@ class AzureCompletion(BaseLLM):
and "/openai/deployments/" in self.endpoint
)
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
@staticmethod
def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
"""Validate and fix Azure endpoint URL format.
Azure OpenAI endpoints should be in the format:
@@ -179,10 +199,75 @@ class AzureCompletion(BaseLLM):
return endpoint
def _handle_api_error(
self,
error: Exception,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None:
"""Handle API errors with appropriate logging and events.
Args:
error: The exception that occurred
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Raises:
The original exception after logging and emitting events
"""
if isinstance(error, HttpResponseError):
if error.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif error.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif error.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = (
f"Azure API HTTP error: {error.status_code} - {error.message}"
)
else:
error_msg = f"Azure API call failed: {error!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise error
def _handle_completion_error(
self,
error: Exception,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> None:
"""Handle completion-specific errors including context length checks.
Args:
error: The exception that occurred
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Raises:
LLMContextLengthExceededError if context window exceeded, otherwise the original exception
"""
if is_context_length_exceeded(error):
logging.error(f"Context window exceeded: {error}")
raise LLMContextLengthExceededError(str(error)) from error
error_msg = f"Azure API call failed: {error!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise error
def call(
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -198,6 +283,7 @@ class AzureCompletion(BaseLLM):
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model
Returns:
Chat completion response or tool call result
@@ -216,6 +302,9 @@ class AzureCompletion(BaseLLM):
# Format messages for Azure
formatted_messages = self._format_messages_for_azure(messages)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare completion parameters
completion_params = self._prepare_completion_params(
formatted_messages, tools, response_model
@@ -239,35 +328,13 @@ class AzureCompletion(BaseLLM):
response_model,
)
except HttpResponseError as e:
if e.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif e.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif e.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
return self._handle_api_error(e, from_task, from_agent) # type: ignore[func-returns-value]
async def acall(
async def acall( # type: ignore[return]
self,
messages: str | list[LLMMessage],
tools: list[dict[str, BaseTool]] | None = None,
tools: list[dict[str, Any]] | None = None,
callbacks: list[Any] | None = None,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -321,37 +388,15 @@ class AzureCompletion(BaseLLM):
response_model,
)
except HttpResponseError as e:
if e.status_code == 401:
error_msg = "Azure authentication failed. Check your API key."
elif e.status_code == 404:
error_msg = (
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
)
elif e.status_code == 429:
error_msg = "Azure API rate limit exceeded. Please retry later."
else:
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
except Exception as e:
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise
self._handle_api_error(e, from_task, from_agent)
def _prepare_completion_params(
self,
messages: list[LLMMessage],
tools: list[dict[str, Any]] | None = None,
response_model: type[BaseModel] | None = None,
) -> dict[str, Any]:
) -> AzureCompletionParams:
"""Prepare parameters for Azure AI Inference chat completion.
Args:
@@ -362,11 +407,14 @@ class AzureCompletion(BaseLLM):
Returns:
Parameters dictionary for Azure API
"""
params = {
params: AzureCompletionParams = {
"messages": messages,
"stream": self.stream,
}
if self.stream:
params["model_extras"] = {"stream_options": {"include_usage": True}}
if response_model and self.is_openai_model:
model_description = generate_model_description(response_model)
json_schema_info = model_description["json_schema"]
@@ -409,37 +457,42 @@ class AzureCompletion(BaseLLM):
if drop_params and isinstance(additional_drop_params, list):
for drop_param in additional_drop_params:
params.pop(drop_param, None)
if isinstance(drop_param, str):
params.pop(drop_param, None) # type: ignore[misc]
return params
def _convert_tools_for_interference(
def _convert_tools_for_interference( # type: ignore[override]
self, tools: list[dict[str, Any]]
) -> list[dict[str, Any]]:
"""Convert CrewAI tool format to Azure OpenAI function calling format."""
) -> list[ChatCompletionsToolDefinition]:
"""Convert CrewAI tool format to Azure OpenAI function calling format.
Args:
tools: List of CrewAI tool definitions
Returns:
List of Azure ChatCompletionsToolDefinition objects
"""
from crewai.llms.providers.utils.common import safe_tool_conversion
azure_tools = []
azure_tools: list[ChatCompletionsToolDefinition] = []
for tool in tools:
name, description, parameters = safe_tool_conversion(tool, "Azure")
azure_tool = {
"type": "function",
"function": {
"name": name,
"description": description,
},
}
function_def = FunctionDefinition(
name=name,
description=description,
parameters=parameters
if isinstance(parameters, dict)
else dict(parameters)
if parameters
else None,
)
if parameters:
if isinstance(parameters, dict):
azure_tool["function"]["parameters"] = parameters # type: ignore
else:
azure_tool["function"]["parameters"] = dict(parameters)
tool_def = ChatCompletionsToolDefinition(function=function_def)
azure_tools.append(azure_tool)
azure_tools.append(tool_def)
return azure_tools
@@ -468,144 +521,239 @@ class AzureCompletion(BaseLLM):
return azure_messages
def _handle_completion(
def _validate_and_emit_structured_output(
self,
params: dict[str, Any],
available_functions: dict[str, Any] | None = None,
content: str,
response_model: type[BaseModel],
params: AzureCompletionParams,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
# Make API call
) -> str:
"""Validate content against response model and emit completion event.
Args:
content: Response content to validate
response_model: Pydantic model for validation
params: Completion parameters containing messages
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Validated and serialized JSON string
Raises:
ValueError: If validation fails
"""
try:
response: ChatCompletions = self.client.complete(**params)
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
if not response.choices:
raise ValueError("No choices returned from Azure API")
choice = response.choices[0]
message = choice.message
# Extract and track token usage
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and self.is_openai_model:
content = message.content or ""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
raise ValueError(error_msg) from e
# Handle tool calls
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
if isinstance(tool_call, ChatCompletionsToolCall):
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Extract content
content = message.content or ""
# Apply stop words
content = self._apply_stop_words(content)
# Emit completion event and return content
self._emit_call_completed_event(
response=content,
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"Azure API call failed: {e!s}"
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e
raise ValueError(error_msg) from e
return content
def _handle_streaming_completion(
def _process_completion_response(
self,
params: dict[str, Any],
response: ChatCompletions,
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Process completion response with usage tracking, tool execution, and events.
Args:
response: Chat completion response from Azure API
params: Completion parameters containing messages
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Pydantic model for structured output
Returns:
Response content or structured output
"""
if not response.choices:
raise ValueError("No choices returned from Azure API")
choice = response.choices[0]
message = choice.message
# Extract and track token usage
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and self.is_openai_model:
content = message.content or ""
return self._validate_and_emit_structured_output(
content=content,
response_model=response_model,
params=params,
from_task=from_task,
from_agent=from_agent,
)
# Handle tool calls
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
if isinstance(tool_call, ChatCompletionsToolCall):
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
# Execute tool
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
# Extract content
content = message.content or ""
# Apply stop words
content = self._apply_stop_words(content)
# Emit completion event and return content
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
def _handle_completion(
self,
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming chat completion."""
try:
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
return self._process_completion_response(
response=response,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
except Exception as e:
return self._handle_completion_error(e, from_task, from_agent) # type: ignore[func-returns-value]
def _process_streaming_update(
self,
update: StreamingChatCompletionsUpdate,
full_response: str,
tool_calls: dict[str, dict[str, str]],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Handle streaming chat completion."""
full_response = ""
tool_calls = {}
"""Process a single streaming update chunk.
# Make streaming API call
for update in self.client.complete(**params):
if isinstance(update, StreamingChatCompletionsUpdate):
if update.choices:
choice = update.choices[0]
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
self._emit_stream_chunk_event(
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
)
Args:
update: Streaming update from Azure API
full_response: Accumulated response content
tool_calls: Dictionary of accumulated tool calls
from_task: Task that initiated the call
from_agent: Agent that initiated the call
# Handle tool call streaming
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
Returns:
Updated full_response string
"""
if update.choices:
choice = update.choices[0]
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
self._emit_stream_chunk_event(
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
)
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += (
tool_call.function.arguments
)
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
return full_response
def _finalize_streaming_response(
self,
full_response: str,
tool_calls: dict[str, dict[str, str]],
usage_data: dict[str, int],
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Finalize streaming response with usage tracking, tool execution, and events.
Args:
full_response: The complete streamed response content
tool_calls: Dictionary of tool calls accumulated during streaming
usage_data: Token usage data from the stream
params: Completion parameters containing messages
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Pydantic model for structured output validation
Returns:
Final response content after processing, or structured output
"""
self._track_token_usage_internal(usage_data)
# Handle structured output validation
if response_model and self.is_openai_model:
return self._validate_and_emit_structured_output(
content=full_response,
response_model=response_model,
params=params,
from_task=from_task,
from_agent=from_agent,
)
# Handle completed tool calls
if tool_calls and available_functions:
@@ -642,11 +790,56 @@ class AzureCompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
def _handle_streaming_completion(
self,
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle streaming chat completion."""
full_response = ""
tool_calls: dict[str, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
for update in self.client.complete(**params): # type: ignore[arg-type]
if isinstance(update, StreamingChatCompletionsUpdate):
if update.usage:
usage = update.usage
usage_data = {
"prompt_tokens": usage.prompt_tokens,
"completion_tokens": usage.completion_tokens,
"total_tokens": usage.total_tokens,
}
continue
full_response = self._process_streaming_update(
update=update,
full_response=full_response,
tool_calls=tool_calls,
from_task=from_task,
from_agent=from_agent,
)
return self._finalize_streaming_response(
full_response=full_response,
tool_calls=tool_calls,
usage_data=usage_data,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
async def _ahandle_completion(
self,
params: dict[str, Any],
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
@@ -654,160 +847,64 @@ class AzureCompletion(BaseLLM):
) -> str | Any:
"""Handle non-streaming chat completion asynchronously."""
try:
response: ChatCompletions = await self.async_client.complete(**params)
if not response.choices:
raise ValueError("No choices returned from Azure API")
choice = response.choices[0]
message = choice.message
usage = self._extract_azure_token_usage(response)
self._track_token_usage_internal(usage)
if response_model and self.is_openai_model:
content = message.content or ""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
raise ValueError(error_msg) from e
if message.tool_calls and available_functions:
tool_call = message.tool_calls[0] # Handle first tool call
if isinstance(tool_call, ChatCompletionsToolCall):
function_name = tool_call.function.name
try:
function_args = json.loads(tool_call.function.arguments)
except json.JSONDecodeError as e:
logging.error(f"Failed to parse tool arguments: {e}")
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = message.content or ""
content = self._apply_stop_words(content)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
# Cast params to Any to avoid type checking issues with TypedDict unpacking
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
return self._process_completion_response(
response=response,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
response_model=response_model,
)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
error_msg = f"Azure API call failed: {e!s}"
logging.error(error_msg)
self._emit_call_failed_event(
error=error_msg, from_task=from_task, from_agent=from_agent
)
raise e
return content
return self._handle_completion_error(e, from_task, from_agent) # type: ignore[func-returns-value]
async def _ahandle_streaming_completion(
self,
params: dict[str, Any],
params: AzureCompletionParams,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
) -> str | Any:
"""Handle streaming chat completion asynchronously."""
full_response = ""
tool_calls = {}
tool_calls: dict[str, dict[str, Any]] = {}
stream = await self.async_client.complete(**params)
async for update in stream:
usage_data = {"total_tokens": 0}
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
async for update in stream: # type: ignore[union-attr]
if isinstance(update, StreamingChatCompletionsUpdate):
if update.choices:
choice = update.choices[0]
if choice.delta and choice.delta.content:
content_delta = choice.delta.content
full_response += content_delta
self._emit_stream_chunk_event(
chunk=content_delta,
from_task=from_task,
from_agent=from_agent,
)
if choice.delta and choice.delta.tool_calls:
for tool_call in choice.delta.tool_calls:
call_id = tool_call.id or "default"
if call_id not in tool_calls:
tool_calls[call_id] = {
"name": "",
"arguments": "",
}
if tool_call.function and tool_call.function.name:
tool_calls[call_id]["name"] = tool_call.function.name
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += (
tool_call.function.arguments
)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
try:
function_args = json.loads(call_data["arguments"])
except json.JSONDecodeError as e:
logging.error(f"Failed to parse streamed tool arguments: {e}")
if hasattr(update, "usage") and update.usage:
usage = update.usage
usage_data = {
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
"completion_tokens": getattr(usage, "completion_tokens", 0),
"total_tokens": getattr(usage, "total_tokens", 0),
}
continue
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
full_response = self._process_streaming_update(
update=update,
full_response=full_response,
tool_calls=tool_calls,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
full_response = self._apply_stop_words(full_response)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
return self._finalize_streaming_response(
full_response=full_response,
tool_calls=tool_calls,
usage_data=usage_data,
params=params,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
response_model=response_model,
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
# Azure OpenAI models support function calling
@@ -851,7 +948,8 @@ class AzureCompletion(BaseLLM):
# Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_azure_token_usage(self, response: ChatCompletions) -> dict[str, Any]:
@staticmethod
def _extract_azure_token_usage(response: ChatCompletions) -> dict[str, Any]:
"""Extract token usage from Azure response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage

View File

@@ -312,9 +312,14 @@ class BedrockCompletion(BaseLLM):
# Format messages for Converse API
formatted_messages, system_message = self._format_messages_for_converse(
messages # type: ignore[arg-type]
messages
)
if not self._invoke_before_llm_call_hooks(
cast(list[LLMMessage], formatted_messages), from_agent
):
raise ValueError("LLM call blocked by before_llm_call hook")
# Prepare request body
body: BedrockConverseRequestBody = {
"inferenceConfig": self._get_inference_config(),
@@ -356,11 +361,19 @@ class BedrockCompletion(BaseLLM):
if self.stream:
return self._handle_streaming_converse(
formatted_messages, body, available_functions, from_task, from_agent
cast(list[LLMMessage], formatted_messages),
body,
available_functions,
from_task,
from_agent,
)
return self._handle_converse(
formatted_messages, body, available_functions, from_task, from_agent
cast(list[LLMMessage], formatted_messages),
body,
available_functions,
from_task,
from_agent,
)
except Exception as e:
@@ -481,7 +494,7 @@ class BedrockCompletion(BaseLLM):
def _handle_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: Mapping[str, Any] | None = None,
from_task: Any | None = None,
@@ -605,7 +618,11 @@ class BedrockCompletion(BaseLLM):
messages=messages,
)
return text_content
return self._invoke_after_llm_call_hooks(
messages,
text_content,
from_agent,
)
except ClientError as e:
# Handle all AWS ClientError exceptions as per documentation
@@ -662,7 +679,7 @@ class BedrockCompletion(BaseLLM):
def _handle_streaming_converse(
self,
messages: list[dict[str, Any]],
messages: list[LLMMessage],
body: BedrockConverseRequestBody,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -1149,16 +1166,25 @@ class BedrockCompletion(BaseLLM):
messages=messages,
)
return full_response
return self._invoke_after_llm_call_hooks(
messages,
full_response,
from_agent,
)
def _format_messages_for_converse(
self, messages: str | list[dict[str, str]]
self, messages: str | list[LLMMessage]
) -> tuple[list[dict[str, Any]], str | None]:
"""Format messages for Converse API following AWS documentation."""
# Use base class formatting first
formatted_messages = self._format_messages(messages) # type: ignore[arg-type]
"""Format messages for Converse API following AWS documentation.
converse_messages = []
Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
a different content structure: {"role": str, "content": [{"text": str}]}
rather than the standard {"role": str, "content": str}.
"""
# Use base class formatting first
formatted_messages = self._format_messages(messages)
converse_messages: list[dict[str, Any]] = []
system_message: str | None = None
for message in formatted_messages:

View File

@@ -3,7 +3,7 @@ from __future__ import annotations
import logging
import os
import re
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, cast
from pydantic import BaseModel
@@ -105,6 +105,7 @@ class GeminiCompletion(BaseLLM):
self.stream = stream
self.safety_settings = safety_settings or {}
self.stop_sequences = stop_sequences or []
self.tools: list[dict[str, Any]] | None = None
# Model-specific settings
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
@@ -223,10 +224,11 @@ class GeminiCompletion(BaseLLM):
Args:
messages: Input messages for the chat completion
tools: List of tool/function definitions
callbacks: Callback functions (not used as token counts are handled by the reponse)
callbacks: Callback functions (not used as token counts are handled by the response)
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model to use.
Returns:
Chat completion response or tool call result
@@ -246,6 +248,11 @@ class GeminiCompletion(BaseLLM):
messages
)
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
config = self._prepare_generation_config(
system_instruction, tools, response_model
)
@@ -262,7 +269,6 @@ class GeminiCompletion(BaseLLM):
return self._handle_completion(
formatted_content,
system_instruction,
config,
available_functions,
from_task,
@@ -304,6 +310,7 @@ class GeminiCompletion(BaseLLM):
available_functions: Available functions for tool calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Response model to use.
Returns:
Chat completion response or tool call result
@@ -339,7 +346,6 @@ class GeminiCompletion(BaseLLM):
return await self._ahandle_completion(
formatted_content,
system_instruction,
config,
available_functions,
from_task,
@@ -492,35 +498,113 @@ class GeminiCompletion(BaseLLM):
return contents, system_instruction
def _handle_completion(
def _validate_and_emit_structured_output(
self,
content: str,
response_model: type[BaseModel],
messages_for_event: list[LLMMessage],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Validate content against response model and emit completion event.
Args:
content: Response content to validate
response_model: Pydantic model for validation
messages_for_event: Messages to include in event
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Validated and serialized JSON string
Raises:
ValueError: If validation fails
"""
try:
structured_data = response_model.model_validate_json(content)
structured_json = structured_data.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return structured_json
except Exception as e:
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
logging.error(error_msg)
raise ValueError(error_msg) from e
def _finalize_completion_response(
self,
content: str,
contents: list[types.Content],
response_model: type[BaseModel] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
) -> str:
"""Finalize completion response with validation and event emission.
Args:
content: The response content
contents: Original contents for event conversion
response_model: Pydantic model for structured output validation
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Final response content after processing
"""
messages_for_event = self._convert_contents_to_dict(contents)
# Handle structured output validation
if response_model:
return self._validate_and_emit_structured_output(
content=content,
response_model=response_model,
messages_for_event=messages_for_event,
from_task=from_task,
from_agent=from_agent,
)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return self._invoke_after_llm_call_hooks(
messages_for_event, content, from_agent
)
def _process_response_with_tools(
self,
response: GenerateContentResponse,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self.client.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)
"""Process response, execute function calls, and finalize completion.
usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
self._track_token_usage_internal(usage)
Args:
response: The completion response
contents: Original contents for event conversion
available_functions: Available functions for function calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Pydantic model for structured output validation
Returns:
Final response content or function call result
"""
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
@@ -549,59 +633,90 @@ class GeminiCompletion(BaseLLM):
content = response.text or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
return self._finalize_completion_response(
content=content,
contents=contents,
response_model=response_model,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return content
def _handle_streaming_completion(
def _process_stream_chunk(
self,
chunk: GenerateContentResponse,
full_response: str,
function_calls: dict[str, dict[str, Any]],
usage_data: dict[str, int],
from_task: Any | None = None,
from_agent: Any | None = None,
) -> tuple[str, dict[str, dict[str, Any]], dict[str, int]]:
"""Process a single streaming chunk.
Args:
chunk: The streaming chunk response
full_response: Accumulated response text
function_calls: Accumulated function calls
usage_data: Accumulated usage data
from_task: Task that initiated the call
from_agent: Agent that initiated the call
Returns:
Tuple of (updated full_response, updated function_calls, updated usage_data)
"""
if chunk.usage_metadata:
usage_data = self._extract_token_usage(chunk)
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
return full_response, function_calls, usage_data
def _finalize_streaming_response(
self,
full_response: str,
function_calls: dict[str, dict[str, Any]],
usage_data: dict[str, int],
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
"""Finalize streaming response with usage tracking, function execution, and events.
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self.client.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
):
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
Args:
full_response: The complete streamed response content
function_calls: Dictionary of function calls accumulated during streaming
usage_data: Token usage data from the stream
contents: Original contents for event conversion
available_functions: Available functions for function calling
from_task: Task that initiated the call
from_agent: Agent that initiated the call
response_model: Pydantic model for structured output validation
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
Returns:
Final response content after processing
"""
self._track_token_usage_internal(usage_data)
# Handle completed function calls
if function_calls and available_functions:
@@ -629,22 +744,95 @@ class GeminiCompletion(BaseLLM):
if result is not None:
return result
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
return self._finalize_completion_response(
content=full_response,
contents=contents,
response_model=response_model,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
)
return full_response
def _handle_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str | Any:
"""Handle non-streaming content generation."""
try:
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
response = self.client.models.generate_content(
model=self.model,
contents=contents_for_api,
config=config,
)
usage = self._extract_token_usage(response)
except Exception as e:
if is_context_length_exceeded(e):
logging.error(f"Context window exceeded: {e}")
raise LLMContextLengthExceededError(str(e)) from e
raise e from e
self._track_token_usage_internal(usage)
return self._process_response_with_tools(
response=response,
contents=contents,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
def _handle_streaming_completion(
self,
contents: list[types.Content],
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
from_agent: Any | None = None,
response_model: type[BaseModel] | None = None,
) -> str:
"""Handle streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
for chunk in self.client.models.generate_content_stream(
model=self.model,
contents=contents_for_api,
config=config,
):
full_response, function_calls, usage_data = self._process_stream_chunk(
chunk=chunk,
full_response=full_response,
function_calls=function_calls,
usage_data=usage_data,
from_task=from_task,
from_agent=from_agent,
)
return self._finalize_streaming_response(
full_response=full_response,
function_calls=function_calls,
usage_data=usage_data,
contents=contents,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
response_model=response_model,
)
async def _ahandle_completion(
self,
contents: list[types.Content],
system_instruction: str | None,
config: types.GenerateContentConfig,
available_functions: dict[str, Any] | None = None,
from_task: Any | None = None,
@@ -670,46 +858,15 @@ class GeminiCompletion(BaseLLM):
self._track_token_usage_internal(usage)
if response.candidates and (self.tools or available_functions):
candidate = response.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
function_name = part.function_call.name
if function_name is None:
continue
function_args = (
dict(part.function_call.args)
if part.function_call.args
else {}
)
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions or {},
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
content = response.text or ""
content = self._apply_stop_words(content)
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=content,
call_type=LLMCallType.LLM_CALL,
return self._process_response_with_tools(
response=response,
contents=contents,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
response_model=response_model,
)
return content
async def _ahandle_streaming_completion(
self,
contents: list[types.Content],
@@ -722,6 +879,7 @@ class GeminiCompletion(BaseLLM):
"""Handle async streaming content generation."""
full_response = ""
function_calls: dict[str, dict[str, Any]] = {}
usage_data = {"total_tokens": 0}
# The API accepts list[Content] but mypy is overly strict about variance
contents_for_api: Any = contents
@@ -731,64 +889,26 @@ class GeminiCompletion(BaseLLM):
config=config,
)
async for chunk in stream:
if chunk.text:
full_response += chunk.text
self._emit_stream_chunk_event(
chunk=chunk.text,
from_task=from_task,
from_agent=from_agent,
)
full_response, function_calls, usage_data = self._process_stream_chunk(
chunk=chunk,
full_response=full_response,
function_calls=function_calls,
usage_data=usage_data,
from_task=from_task,
from_agent=from_agent,
)
if chunk.candidates:
candidate = chunk.candidates[0]
if candidate.content and candidate.content.parts:
for part in candidate.content.parts:
if hasattr(part, "function_call") and part.function_call:
call_id = part.function_call.name or "default"
if call_id not in function_calls:
function_calls[call_id] = {
"name": part.function_call.name,
"args": dict(part.function_call.args)
if part.function_call.args
else {},
}
if function_calls and available_functions:
for call_data in function_calls.values():
function_name = call_data["name"]
function_args = call_data["args"]
# Skip if function_name is None
if not isinstance(function_name, str):
continue
# Ensure function_args is a dict
if not isinstance(function_args, dict):
function_args = {}
result = self._handle_tool_execution(
function_name=function_name,
function_args=function_args,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
)
if result is not None:
return result
messages_for_event = self._convert_contents_to_dict(contents)
self._emit_call_completed_event(
response=full_response,
call_type=LLMCallType.LLM_CALL,
return self._finalize_streaming_response(
full_response=full_response,
function_calls=function_calls,
usage_data=usage_data,
contents=contents,
available_functions=available_functions,
from_task=from_task,
from_agent=from_agent,
messages=messages_for_event,
response_model=response_model,
)
return full_response
def supports_function_calling(self) -> bool:
"""Check if the model supports function calling."""
return self.supports_tools
@@ -848,12 +968,12 @@ class GeminiCompletion(BaseLLM):
}
return {"total_tokens": 0}
@staticmethod
def _convert_contents_to_dict(
self,
contents: list[types.Content],
) -> list[dict[str, str]]:
) -> list[LLMMessage]:
"""Convert contents to dict format."""
result: list[dict[str, str]] = []
result: list[LLMMessage] = []
for content_obj in contents:
role = content_obj.role
if role == "model":
@@ -866,5 +986,10 @@ class GeminiCompletion(BaseLLM):
part.text for part in parts if hasattr(part, "text") and part.text
)
result.append({"role": role, "content": content})
result.append(
LLMMessage(
role=cast(Literal["user", "assistant", "system"], role),
content=content,
)
)
return result

View File

@@ -1,13 +1,14 @@
from __future__ import annotations
from collections.abc import AsyncIterator, Iterator
from collections.abc import AsyncIterator
import json
import logging
import os
from typing import TYPE_CHECKING, Any
import httpx
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
from openai.lib.streaming.chat import ChatCompletionStream
from openai.types.chat import ChatCompletion, ChatCompletionChunk
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import ChoiceDelta
@@ -189,6 +190,9 @@ class OpenAICompletion(BaseLLM):
formatted_messages = self._format_messages(messages)
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
raise ValueError("LLM call blocked by before_llm_call hook")
completion_params = self._prepare_completion_params(
messages=formatted_messages, tools=tools
)
@@ -293,6 +297,7 @@ class OpenAICompletion(BaseLLM):
}
if self.stream:
params["stream"] = self.stream
params["stream_options"] = {"include_usage": True}
params.update(self.additional_params)
@@ -473,6 +478,10 @@ class OpenAICompletion(BaseLLM):
if usage.get("total_tokens", 0) > 0:
logging.info(f"OpenAI API usage: {usage}")
content = self._invoke_after_llm_call_hooks(
params["messages"], content, from_agent
)
except NotFoundError as e:
error_msg = f"Model {self.model} not found: {e}"
logging.error(error_msg)
@@ -515,59 +524,61 @@ class OpenAICompletion(BaseLLM):
tool_calls = {}
if response_model:
completion_stream: Iterator[ChatCompletionChunk] = (
self.client.chat.completions.create(**params)
)
parse_params = {
k: v
for k, v in params.items()
if k not in ("response_format", "stream")
}
accumulated_content = ""
for chunk in completion_stream:
if not chunk.choices:
continue
stream: ChatCompletionStream[BaseModel]
with self.client.beta.chat.completions.stream(
**parse_params, response_format=response_model
) as stream:
for chunk in stream:
if chunk.type == "content.delta":
delta_content = chunk.delta
if delta_content:
self._emit_stream_chunk_event(
chunk=delta_content,
from_task=from_task,
from_agent=from_agent,
)
choice = chunk.choices[0]
delta: ChoiceDelta = choice.delta
final_completion = stream.get_final_completion()
if final_completion:
usage = self._extract_openai_token_usage(final_completion)
self._track_token_usage_internal(usage)
if final_completion.choices:
parsed_result = final_completion.choices[0].message.parsed
if parsed_result:
structured_json = parsed_result.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
if delta.content:
accumulated_content += delta.content
self._emit_stream_chunk_event(
chunk=delta.content,
from_task=from_task,
from_agent=from_agent,
)
logging.error("Failed to get parsed result from stream")
return ""
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
self._emit_call_completed_event(
response=structured_json,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return structured_json
except Exception as e:
logging.error(f"Failed to parse structured output from stream: {e}")
self._emit_call_completed_event(
response=accumulated_content,
call_type=LLMCallType.LLM_CALL,
from_task=from_task,
from_agent=from_agent,
messages=params["messages"],
)
return accumulated_content
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
**params
completion_stream: Stream[ChatCompletionChunk] = (
self.client.chat.completions.create(**params)
)
for chunk in stream:
if not chunk.choices:
usage_data = {"total_tokens": 0}
for completion_chunk in completion_stream:
if hasattr(completion_chunk, "usage") and completion_chunk.usage:
usage_data = self._extract_openai_token_usage(completion_chunk)
continue
choice = chunk.choices[0]
if not completion_chunk.choices:
continue
choice = completion_chunk.choices[0]
chunk_delta: ChoiceDelta = choice.delta
if chunk_delta.content:
@@ -592,6 +603,8 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
@@ -635,7 +648,9 @@ class OpenAICompletion(BaseLLM):
messages=params["messages"],
)
return full_response
return self._invoke_after_llm_call_hooks(
params["messages"], full_response, from_agent
)
async def _ahandle_completion(
self,
@@ -782,7 +797,12 @@ class OpenAICompletion(BaseLLM):
] = await self.async_client.chat.completions.create(**params)
accumulated_content = ""
usage_data = {"total_tokens": 0}
async for chunk in completion_stream:
if hasattr(chunk, "usage") and chunk.usage:
usage_data = self._extract_openai_token_usage(chunk)
continue
if not chunk.choices:
continue
@@ -797,6 +817,8 @@ class OpenAICompletion(BaseLLM):
from_agent=from_agent,
)
self._track_token_usage_internal(usage_data)
try:
parsed_object = response_model.model_validate_json(accumulated_content)
structured_json = parsed_object.model_dump_json()
@@ -825,7 +847,13 @@ class OpenAICompletion(BaseLLM):
ChatCompletionChunk
] = await self.async_client.chat.completions.create(**params)
usage_data = {"total_tokens": 0}
async for chunk in stream:
if hasattr(chunk, "usage") and chunk.usage:
usage_data = self._extract_openai_token_usage(chunk)
continue
if not chunk.choices:
continue
@@ -854,6 +882,8 @@ class OpenAICompletion(BaseLLM):
if tool_call.function and tool_call.function.arguments:
tool_calls[call_id]["arguments"] += tool_call.function.arguments
self._track_token_usage_internal(usage_data)
if tool_calls and available_functions:
for call_data in tool_calls.values():
function_name = call_data["name"]
@@ -941,8 +971,10 @@ class OpenAICompletion(BaseLLM):
# Default context window size
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletion response."""
def _extract_openai_token_usage(
self, response: ChatCompletion | ChatCompletionChunk
) -> dict[str, Any]:
"""Extract token usage from OpenAI ChatCompletion or ChatCompletionChunk response."""
if hasattr(response, "usage") and response.usage:
usage = response.usage
return {

View File

@@ -1,5 +1,6 @@
from __future__ import annotations
import asyncio
from typing import TYPE_CHECKING
from crewai.memory import (
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
class ContextualMemory:
"""Aggregates and retrieves context from multiple memory sources."""
def __init__(
self,
stm: ShortTermMemory,
@@ -46,9 +49,14 @@ class ContextualMemory:
self.exm.task = self.task
def build_context_for_task(self, task: Task, context: str) -> str:
"""
Automatically builds a minimal, highly relevant set of contextual information
for a given task.
"""Build contextual information for a task synchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
@@ -63,6 +71,31 @@ class ContextualMemory:
]
return "\n".join(filter(None, context_parts))
async def abuild_context_for_task(self, task: Task, context: str) -> str:
"""Build contextual information for a task asynchronously.
Args:
task: The task to build context for.
context: Additional context string.
Returns:
Formatted context string from all memory sources.
"""
query = f"{task.description} {context}".strip()
if query == "":
return ""
# Fetch all contexts concurrently
results = await asyncio.gather(
self._afetch_ltm_context(task.description),
self._afetch_stm_context(query),
self._afetch_entity_context(query),
self._afetch_external_context(query),
)
return "\n".join(filter(None, results))
def _fetch_stm_context(self, query: str) -> str:
"""
Fetches recent relevant insights from STM related to the task's description and expected_output,
@@ -135,3 +168,87 @@ class ContextualMemory:
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"
async def _afetch_stm_context(self, query: str) -> str:
"""Fetch recent relevant insights from STM asynchronously.
Args:
query: The search query.
Returns:
Formatted insights as bullet points, or empty string if none found.
"""
if self.stm is None:
return ""
stm_results = await self.stm.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in stm_results]
)
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
async def _afetch_ltm_context(self, task: str) -> str | None:
"""Fetch historical data from LTM asynchronously.
Args:
task: The task description to search for.
Returns:
Formatted historical data as bullet points, or None if none found.
"""
if self.ltm is None:
return ""
ltm_results = await self.ltm.asearch(task, latest_n=2)
if not ltm_results:
return None
formatted_results = [
suggestion
for result in ltm_results
for suggestion in result["metadata"]["suggestions"]
]
formatted_results = list(dict.fromkeys(formatted_results))
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
async def _afetch_entity_context(self, query: str) -> str:
"""Fetch relevant entity information asynchronously.
Args:
query: The search query.
Returns:
Formatted entity information as bullet points, or empty string if none found.
"""
if self.em is None:
return ""
em_results = await self.em.asearch(query)
formatted_results = "\n".join(
[f"- {result['content']}" for result in em_results]
)
return f"Entities:\n{formatted_results}" if em_results else ""
async def _afetch_external_context(self, query: str) -> str:
"""Fetch relevant information from External Memory asynchronously.
Args:
query: The search query.
Returns:
Formatted information as bullet points, or empty string if none found.
"""
if self.exm is None:
return ""
external_memories = await self.exm.asearch(query)
if not external_memories:
return ""
formatted_memories = "\n".join(
f"- {result['content']}" for result in external_memories
)
return f"External memories:\n{formatted_memories}"

View File

@@ -26,7 +26,13 @@ class EntityMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search entity memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
)
raise
async def asave(
self,
value: EntityMemoryItem | list[EntityMemoryItem],
metadata: dict[str, Any] | None = None,
) -> None:
"""Save entity items asynchronously.
Args:
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
metadata: Optional metadata dict (not used, for signature compatibility).
"""
if not value:
return
items = value if isinstance(value, list) else [value]
is_batch = len(items) > 1
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
metadata=metadata,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
saved_count = 0
errors: list[str | None] = []
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
"""Save a single item asynchronously."""
try:
if self._memory_provider == "mem0":
data = f"""
Remember details about the following entity:
Name: {item.name}
Type: {item.type}
Entity Description: {item.description}
"""
else:
data = f"{item.name}({item.type}): {item.description}"
await super(EntityMemory, self).asave(data, item.metadata)
return True, None
except Exception as e:
return False, f"{item.name}: {e!s}"
try:
for item in items:
success, error = await save_single_item(item)
if success:
saved_count += 1
else:
errors.append(error)
if is_batch:
emit_value = f"Saved {saved_count} entities"
metadata = {"entity_count": saved_count, "errors": errors}
else:
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
metadata = items[0].metadata
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=emit_value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
if errors:
raise Exception(
f"Partial save: {len(errors)} failed out of {len(items)}"
)
except Exception as e:
fail_metadata = (
{"entity_count": len(items), "saved": saved_count}
if is_batch
else items[0].metadata
)
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
metadata=fail_metadata,
error=str(e),
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search entity memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="entity_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="entity_memory",
),
)
raise
def reset(self) -> None:
try:
self.storage.reset()

View File

@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
from crewai.memory.storage.mem0_storage import Mem0Storage
return Mem0Storage(type="external", crew=crew, config=config)
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
@staticmethod
def external_supported_storages() -> dict[str, Any]:
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
if provider not in supported_storages:
raise ValueError(f"Provider {provider} not supported")
return supported_storages[provider](crew, embedder_config.get("config", {}))
storage: Storage = supported_storages[provider](
crew, embedder_config.get("config", {})
)
return storage
def save(
self,
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search external memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to external memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ExternalMemoryItem(
value=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
await super().asave(value=item.value, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search external memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await super().asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="external_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="external_memory",
),
)
raise
def reset(self) -> None:
self.storage.reset()

View File

@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
LongTermMemoryItem instances.
"""
def __init__(self, storage=None, path=None):
def __init__(
self,
storage: LTMSQLiteStorage | None = None,
path: str | None = None,
) -> None:
if not storage:
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
super().__init__(storage=storage)
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
self.storage.save(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
)
raise
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
def search( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
) -> list[dict[str, Any]]:
"""Search long-term memory for relevant entries.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
start_time = time.time()
try:
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
results = self.storage.load(task, latest_n)
crewai_event_bus.emit(
self,
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
),
)
return results
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=task,
limit=latest_n,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
"""Save an item to long-term memory asynchronously.
Args:
item: The LongTermMemoryItem to save.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
metadata = item.metadata
metadata.update(
{"agent": item.agent, "expected_output": item.expected_output}
)
await self.storage.asave(
task_description=item.task,
score=metadata["quality"],
metadata=metadata,
datetime=item.datetime,
)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
save_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=item.task,
metadata=item.metadata,
agent_role=item.agent,
error=str(e),
source_type="long_term_memory",
),
)
raise
async def asearch( # type: ignore[override]
self,
task: str,
latest_n: int = 3,
) -> list[dict[str, Any]]:
"""Search long-term memory asynchronously.
Args:
task: The task description to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=task,
limit=latest_n,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.aload(task, latest_n)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=task,
results=results,
limit=latest_n,
query_time_ms=(time.time() - start_time) * 1000,
source_type="long_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return results or []
except Exception as e:
crewai_event_bus.emit(
self,
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
raise
def reset(self) -> None:
"""Reset long-term memory."""
self.storage.reset()

View File

@@ -13,9 +13,7 @@ if TYPE_CHECKING:
class Memory(BaseModel):
"""
Base class for memory, now supporting agent tags and generic metadata.
"""
"""Base class for memory, supporting agent tags and generic metadata."""
embedder_config: EmbedderConfig | dict[str, Any] | None = None
crew: Any | None = None
@@ -52,20 +50,72 @@ class Memory(BaseModel):
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
metadata = metadata or {}
"""Save a value to memory.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
self.storage.save(value, metadata)
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
metadata = metadata or {}
await self.storage.asave(value, metadata)
def search(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
return self.storage.search(
"""Search memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
)
return results
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search memory for relevant entries asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
results: list[Any] = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
return results
def set_crew(self, crew: Any) -> Memory:
"""Set the crew for this memory instance."""
self.crew = crew
return self

View File

@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
_memory_provider: str | None = PrivateAttr()
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
def __init__(
self,
crew: Any = None,
embedder_config: Any = None,
storage: Any = None,
path: str | None = None,
) -> None:
memory_provider = None
if embedder_config and isinstance(embedder_config, dict):
memory_provider = embedder_config.get("provider")
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
if embedder_config and isinstance(embedder_config, dict)
else None
)
storage = Mem0Storage(type="short_term", crew=crew, config=config)
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
else:
storage = (
storage
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
query: str,
limit: int = 5,
score_threshold: float = 0.6,
):
) -> list[Any]:
"""Search short-term memory for relevant entries.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
try:
results = self.storage.search(
query=query, limit=limit, score_threshold=score_threshold
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
)
crewai_event_bus.emit(
self,
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
),
)
return results
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemoryQueryFailedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
error=str(e),
source_type="short_term_memory",
),
)
raise
async def asave(
self,
value: Any,
metadata: dict[str, Any] | None = None,
) -> None:
"""Save a value to short-term memory asynchronously.
Args:
value: The value to save.
metadata: Optional metadata to associate with the value.
"""
crewai_event_bus.emit(
self,
event=MemorySaveStartedEvent(
value=value,
metadata=metadata,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
item = ShortTermMemoryItem(
data=value,
metadata=metadata,
agent=self.agent.role if self.agent else None,
)
if self._memory_provider == "mem0":
item.data = (
f"Remember the following insights from Agent run: {item.data}"
)
await super().asave(value=item.data, metadata=item.metadata)
crewai_event_bus.emit(
self,
event=MemorySaveCompletedEvent(
value=value,
metadata=metadata,
save_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
except Exception as e:
crewai_event_bus.emit(
self,
event=MemorySaveFailedEvent(
value=value,
metadata=metadata,
error=str(e),
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
raise
async def asearch(
self,
query: str,
limit: int = 5,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search short-term memory asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
score_threshold: Minimum similarity score for results.
Returns:
List of matching memory entries.
"""
crewai_event_bus.emit(
self,
event=MemoryQueryStartedEvent(
query=query,
limit=limit,
score_threshold=score_threshold,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
start_time = time.time()
try:
results = await self.storage.asearch(
query=query, limit=limit, score_threshold=score_threshold
)
crewai_event_bus.emit(
self,
event=MemoryQueryCompletedEvent(
query=query,
results=results,
limit=limit,
score_threshold=score_threshold,
query_time_ms=(time.time() - start_time) * 1000,
source_type="short_term_memory",
from_agent=self.agent,
from_task=self.task,
),
)
return list(results)
except Exception as e:
crewai_event_bus.emit(
self,

View File

@@ -3,29 +3,30 @@ from pathlib import Path
import sqlite3
from typing import Any
import aiosqlite
from crewai.utilities import Printer
from crewai.utilities.paths import db_storage_path
class LTMSQLiteStorage:
"""
An updated SQLite storage class for LTM data storage.
"""
"""SQLite storage class for long-term memory data."""
def __init__(self, db_path: str | None = None) -> None:
"""Initialize the SQLite storage.
Args:
db_path: Optional path to the database file.
"""
if db_path is None:
# Get the parent directory of the default db path and create our db file there
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
self.db_path = db_path
self._printer: Printer = Printer()
# Ensure parent directory exists
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
self._initialize_db()
def _initialize_db(self):
"""
Initializes the SQLite database and creates LTM table
"""
def _initialize_db(self) -> None:
"""Initialize the SQLite database and create LTM table."""
try:
with sqlite3.connect(self.db_path) as conn:
cursor = conn.cursor()
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
)
return None
def reset(
self,
) -> None:
def reset(self) -> None:
"""Resets the LTM table with error handling."""
try:
with sqlite3.connect(self.db_path) as conn:
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)
return
async def asave(
self,
task_description: str,
metadata: dict[str, Any],
datetime: str,
score: int | float,
) -> None:
"""Save data to the LTM table asynchronously.
Args:
task_description: Description of the task.
metadata: Metadata associated with the memory.
datetime: Timestamp of the memory.
score: Quality score of the memory.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute(
"""
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
VALUES (?, ?, ?, ?)
""",
(task_description, json.dumps(metadata), datetime, score),
)
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
color="red",
)
async def aload(
self, task_description: str, latest_n: int
) -> list[dict[str, Any]] | None:
"""Query the LTM table by task description asynchronously.
Args:
task_description: Description of the task to search for.
latest_n: Maximum number of results to return.
Returns:
List of matching memory entries or None if error occurs.
"""
try:
async with aiosqlite.connect(self.db_path) as conn:
cursor = await conn.execute(
f"""
SELECT metadata, datetime, score
FROM long_term_memories
WHERE task_description = ?
ORDER BY datetime DESC, score ASC
LIMIT {latest_n}
""", # nosec # noqa: S608
(task_description,),
)
rows = await cursor.fetchall()
if rows:
return [
{
"metadata": json.loads(row[0]),
"datetime": row[1],
"score": row[2],
}
for row in rows
]
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
color="red",
)
return None
async def areset(self) -> None:
"""Reset the LTM table asynchronously."""
try:
async with aiosqlite.connect(self.db_path) as conn:
await conn.execute("DELETE FROM long_term_memories")
await conn.commit()
except aiosqlite.Error as e:
self._printer.print(
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
color="red",
)

View File

@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
return f"{base_path}/{file_name}"
def save(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
)
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
"""Save a value to storage asynchronously.
Args:
value: The value to save.
metadata: Metadata to associate with the value.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
await client.aget_or_create_collection(collection_name=collection_name)
document: BaseRecord = {"content": value}
if metadata:
document["metadata"] = metadata
batch_size = None
if (
self.embedder_config
and isinstance(self.embedder_config, dict)
and "config" in self.embedder_config
):
nested_config = self.embedder_config["config"]
if isinstance(nested_config, dict):
batch_size = nested_config.get("batch_size")
if batch_size is not None:
await client.aadd_documents(
collection_name=collection_name,
documents=[document],
batch_size=cast(int, batch_size),
)
else:
await client.aadd_documents(
collection_name=collection_name, documents=[document]
)
except Exception as e:
logging.error(
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
)
def search(
self,
query: str,
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
)
return []
async def asearch(
self,
query: str,
limit: int = 5,
filter: dict[str, Any] | None = None,
score_threshold: float = 0.6,
) -> list[Any]:
"""Search for matching entries in storage asynchronously.
Args:
query: The search query.
limit: Maximum number of results to return.
filter: Optional metadata filter.
score_threshold: Minimum similarity score for results.
Returns:
List of matching entries.
"""
try:
client = self._get_client()
collection_name = (
f"memory_{self.type}_{self.agents}"
if self.agents
else f"memory_{self.type}"
)
return await client.asearch(
collection_name=collection_name,
query=query,
limit=limit,
metadata_filter=filter,
score_threshold=score_threshold,
)
except Exception as e:
logging.error(
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
)
return []
def reset(self) -> None:
try:
client = self._get_client()

View File

@@ -1,21 +1,35 @@
"""HuggingFace embeddings provider."""
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
HuggingFaceEmbeddingServer,
HuggingFaceEmbeddingFunction,
)
from pydantic import AliasChoices, Field
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
"""HuggingFace embeddings provider."""
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
"""HuggingFace embeddings provider for the HuggingFace Inference API."""
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
default=HuggingFaceEmbeddingServer,
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
default=HuggingFaceEmbeddingFunction,
description="HuggingFace embedding function class",
)
url: str = Field(
description="HuggingFace API URL",
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
api_key: str | None = Field(
default=None,
description="HuggingFace API key",
validation_alias=AliasChoices(
"EMBEDDINGS_HUGGINGFACE_API_KEY",
"HUGGINGFACE_API_KEY",
"HF_TOKEN",
),
)
model_name: str = Field(
default="sentence-transformers/all-MiniLM-L6-v2",
description="Model name to use for embeddings",
validation_alias=AliasChoices(
"EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
"HUGGINGFACE_MODEL_NAME",
"model",
),
)

View File

@@ -1,6 +1,6 @@
"""Type definitions for HuggingFace embedding providers."""
from typing import Literal
from typing import Annotated, Literal
from typing_extensions import Required, TypedDict
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
class HuggingFaceProviderConfig(TypedDict, total=False):
"""Configuration for HuggingFace provider."""
url: str
api_key: str
model: Annotated[
str, "sentence-transformers/all-MiniLM-L6-v2"
] # alias for model_name for backward compat
model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
class HuggingFaceProviderSpec(TypedDict, total=False):

View File

@@ -497,6 +497,107 @@ class Task(BaseModel):
result = self._execute_core(agent, context, tools)
future.set_result(result)
async def aexecute_sync(
self,
agent: BaseAgent | None = None,
context: str | None = None,
tools: list[BaseTool] | None = None,
) -> TaskOutput:
"""Execute the task asynchronously using native async/await."""
return await self._aexecute_core(agent, context, tools)
async def _aexecute_core(
self,
agent: BaseAgent | None,
context: str | None,
tools: list[Any] | None,
) -> TaskOutput:
"""Run the core execution logic of the task asynchronously."""
try:
agent = agent or self.agent
self.agent = agent
if not agent:
raise Exception(
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
)
self.start_time = datetime.datetime.now()
self.prompt_context = context
tools = tools or self.tools or []
self.processed_by_agents.add(agent.role)
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
if not self._guardrails and not self._guardrail:
pydantic_output, json_output = self._export_output(result)
else:
pydantic_output, json_output = None, None
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
for idx, guardrail in enumerate(self._guardrails):
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=guardrail,
guardrail_index=idx,
)
if self._guardrail:
task_output = await self._ainvoke_guardrail_function(
task_output=task_output,
agent=agent,
tools=tools,
guardrail=self._guardrail,
)
self.output = task_output
self.end_time = datetime.datetime.now()
if self.callback:
self.callback(self.output)
crew = self.agent.crew # type: ignore[union-attr]
if crew and crew.task_callback and crew.task_callback != self.callback:
crew.task_callback(self.output)
if self.output_file:
content = (
json_output
if json_output
else (
pydantic_output.model_dump_json() if pydantic_output else result
)
)
self._save_file(content)
crewai_event_bus.emit(
self,
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
)
return task_output
except Exception as e:
self.end_time = datetime.datetime.now()
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
raise e # Re-raise the exception after emitting the event
def _execute_core(
self,
agent: BaseAgent | None,
@@ -539,7 +640,7 @@ class Task(BaseModel):
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
if self._guardrails:
@@ -950,7 +1051,103 @@ Follow these guidelines:
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages,
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output
async def _ainvoke_guardrail_function(
self,
task_output: TaskOutput,
agent: BaseAgent,
tools: list[BaseTool],
guardrail: GuardrailCallable | None,
guardrail_index: int | None = None,
) -> TaskOutput:
"""Invoke the guardrail function asynchronously."""
if not guardrail:
return task_output
if guardrail_index is not None:
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
else:
current_retry_count = self.retry_count
max_attempts = self.guardrail_max_retries + 1
for attempt in range(max_attempts):
guardrail_result = process_guardrail(
output=task_output,
guardrail=guardrail,
retry_count=current_retry_count,
event_source=self,
from_task=self,
from_agent=agent,
)
if guardrail_result.success:
if guardrail_result.result is None:
raise Exception(
"Task guardrail returned None as result. This is not allowed."
)
if isinstance(guardrail_result.result, str):
task_output.raw = guardrail_result.result
pydantic_output, json_output = self._export_output(
guardrail_result.result
)
task_output.pydantic = pydantic_output
task_output.json_dict = json_output
elif isinstance(guardrail_result.result, TaskOutput):
task_output = guardrail_result.result
return task_output
if attempt >= self.guardrail_max_retries:
guardrail_name = (
f"guardrail {guardrail_index}"
if guardrail_index is not None
else "guardrail"
)
raise Exception(
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
f"Last error: {guardrail_result.error}"
)
if guardrail_index is not None:
current_retry_count += 1
self._guardrail_retry_counts[guardrail_index] = current_retry_count
else:
self.retry_count += 1
current_retry_count = self.retry_count
context = self.i18n.errors("validation_error").format(
guardrail_result_error=guardrail_result.error,
task_output=task_output.raw,
)
printer = Printer()
printer.print(
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
color="yellow",
)
result = await agent.aexecute_task(
task=self,
context=context,
tools=tools,
)
pydantic_output, json_output = self._export_output(result)
task_output = TaskOutput(
name=self.name or self.description,
description=self.description,
expected_output=self.expected_output,
raw=result,
pydantic=pydantic_output,
json_dict=json_output,
agent=agent.role,
output_format=self._get_output_format(),
messages=agent.last_messages, # type: ignore[attr-defined]
)
return task_output

View File

@@ -392,9 +392,7 @@ class Telemetry:
self._add_attribute(span, "platform_system", platform.system())
self._add_attribute(span, "platform_version", platform.version())
self._add_attribute(span, "cpus", os.cpu_count())
self._add_attribute(
span, "crew_inputs", json.dumps(inputs) if inputs else None
)
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
else:
self._add_attribute(
span,
@@ -707,9 +705,7 @@ class Telemetry:
self._add_attribute(span, "model_name", model_name)
if crew.share_crew:
self._add_attribute(
span, "inputs", json.dumps(inputs) if inputs else None
)
self._add_attribute(span, "inputs", json.dumps(inputs or {}))
close_span(span)
@@ -814,9 +810,7 @@ class Telemetry:
add_crew_attributes(
span, crew, self._add_attribute, include_fingerprint=False
)
self._add_attribute(
span, "crew_inputs", json.dumps(inputs) if inputs else None
)
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
self._add_attribute(
span,
"crew_agents",

View File

@@ -2,9 +2,18 @@ from __future__ import annotations
from abc import ABC, abstractmethod
import asyncio
from collections.abc import Callable
from collections.abc import Awaitable, Callable
from inspect import signature
from typing import Any, cast, get_args, get_origin
from typing import (
Any,
Generic,
ParamSpec,
TypeVar,
cast,
get_args,
get_origin,
overload,
)
from pydantic import (
BaseModel,
@@ -14,6 +23,7 @@ from pydantic import (
create_model,
field_validator,
)
from typing_extensions import TypeIs
from crewai.tools.structured_tool import CrewStructuredTool
from crewai.utilities.printer import Printer
@@ -21,6 +31,19 @@ from crewai.utilities.printer import Printer
_printer = Printer()
P = ParamSpec("P")
R = TypeVar("R", covariant=True)
def _is_async_callable(func: Callable[..., Any]) -> bool:
"""Check if a callable is async."""
return asyncio.iscoroutinefunction(func)
def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]:
"""Type narrowing check for awaitable values."""
return asyncio.iscoroutine(value) or asyncio.isfuture(value)
class EnvVar(BaseModel):
name: str
@@ -55,7 +78,7 @@ class BaseTool(BaseModel, ABC):
default=False, description="Flag to check if the description has been updated."
)
cache_function: Callable = Field(
cache_function: Callable[..., bool] = Field(
default=lambda _args=None, _result=None: True,
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
)
@@ -123,6 +146,35 @@ class BaseTool(BaseModel, ABC):
return result
async def arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Execute the tool asynchronously.
Args:
*args: Positional arguments to pass to the tool.
**kwargs: Keyword arguments to pass to the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(
self,
*args: Any,
**kwargs: Any,
) -> Any:
"""Async implementation of the tool. Override for async support."""
raise NotImplementedError(
f"{self.__class__.__name__} does not implement _arun. "
"Override _arun for async support or use run() for sync execution."
)
def reset_usage_count(self) -> None:
"""Reset the current usage count to zero."""
self.current_usage_count = 0
@@ -133,7 +185,17 @@ class BaseTool(BaseModel, ABC):
*args: Any,
**kwargs: Any,
) -> Any:
"""Here goes the actual implementation of the tool."""
"""Sync implementation of the tool.
Subclasses must implement this method for synchronous execution.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
def to_structured_tool(self) -> CrewStructuredTool:
"""Convert this tool to a CrewStructuredTool instance."""
@@ -239,21 +301,90 @@ class BaseTool(BaseModel, ABC):
if args:
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
return f"{origin.__name__}[{args_str}]"
return str(f"{origin.__name__}[{args_str}]")
return origin.__name__
return str(origin.__name__)
class Tool(BaseTool):
"""The function that will be executed when the tool is called."""
class Tool(BaseTool, Generic[P, R]):
"""Tool that wraps a callable function.
func: Callable
def _run(self, *args: Any, **kwargs: Any) -> Any:
return self.func(*args, **kwargs)
Type Parameters:
P: ParamSpec capturing the function's parameters.
R: The return type of the function.
"""
func: Callable[P, R | Awaitable[R]]
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool synchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
_printer.print(f"Using Tool: {self.name}", color="cyan")
result = self.func(*args, **kwargs)
if asyncio.iscoroutine(result):
result = asyncio.run(result)
self.current_usage_count += 1
return result # type: ignore[return-value]
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the function execution.
"""
return self.func(*args, **kwargs) # type: ignore[return-value]
async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the tool asynchronously.
Args:
*args: Positional arguments for the tool.
**kwargs: Keyword arguments for the tool.
Returns:
The result of the tool execution.
"""
result = await self._arun(*args, **kwargs)
self.current_usage_count += 1
return result
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
"""Executes the wrapped function asynchronously.
Args:
*args: Positional arguments for the function.
**kwargs: Keyword arguments for the function.
Returns:
The result of the async function execution.
Raises:
NotImplementedError: If the wrapped function is not async.
"""
result = self.func(*args, **kwargs)
if _is_awaitable(result):
return await result
raise NotImplementedError(
f"{self.name} does not have an async function. "
"Use run() for sync execution or provide an async function."
)
@classmethod
def from_langchain(cls, tool: Any) -> Tool:
def from_langchain(cls, tool: Any) -> Tool[..., Any]:
"""Create a Tool instance from a CrewStructuredTool.
This method takes a CrewStructuredTool object and converts it into a
@@ -261,10 +392,10 @@ class Tool(BaseTool):
attribute and infers the argument schema if not explicitly provided.
Args:
tool (Any): The CrewStructuredTool object to be converted.
tool: The CrewStructuredTool object to be converted.
Returns:
Tool: A new Tool instance created from the provided CrewStructuredTool.
A new Tool instance created from the provided CrewStructuredTool.
Raises:
ValueError: If the provided tool does not have a callable 'func' attribute.
@@ -308,37 +439,83 @@ class Tool(BaseTool):
def to_langchain(
tools: list[BaseTool | CrewStructuredTool],
) -> list[CrewStructuredTool]:
"""Convert a list of tools to CrewStructuredTool instances."""
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
P2 = ParamSpec("P2")
R2 = TypeVar("R2")
@overload
def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ...
@overload
def tool(
*args, result_as_answer: bool = False, max_usage_count: int | None = None
) -> Callable:
"""
Decorator to create a tool from a function.
name: str,
/,
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
@overload
def tool(
*,
result_as_answer: bool = ...,
max_usage_count: int | None = ...,
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
def tool(
*args: Callable[P2, R2] | str,
result_as_answer: bool = False,
max_usage_count: int | None = None,
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
"""Decorator to create a Tool from a function.
Can be used in three ways:
1. @tool - decorator without arguments, uses function name
2. @tool("name") - decorator with custom name
3. @tool(result_as_answer=True) - decorator with options
Args:
*args: Positional arguments, either the function to decorate or the tool name.
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
*args: Either the function to decorate or a custom tool name.
result_as_answer: If True, the tool result becomes the final agent answer.
max_usage_count: Maximum times this tool can be used. None means unlimited.
Returns:
A Tool instance.
Example:
@tool
def greet(name: str) -> str:
'''Greet someone.'''
return f"Hello, {name}!"
result = greet.run("World")
"""
def _make_with_name(tool_name: str) -> Callable:
def _make_tool(f: Callable) -> BaseTool:
def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]:
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
if f.__doc__ is None:
raise ValueError("Function must have a docstring")
if f.__annotations__ is None:
func_annotations = getattr(f, "__annotations__", None)
if func_annotations is None:
raise ValueError("Function must have type annotations")
class_name = "".join(tool_name.split()).title()
args_schema = cast(
tool_args_schema = cast(
type[PydanticBaseModel],
type(
class_name,
(PydanticBaseModel,),
{
"__annotations__": {
k: v for k, v in f.__annotations__.items() if k != "return"
k: v for k, v in func_annotations.items() if k != "return"
},
},
),
@@ -348,10 +525,9 @@ def tool(
name=tool_name,
description=f.__doc__,
func=f,
args_schema=args_schema,
args_schema=tool_args_schema,
result_as_answer=result_as_answer,
max_usage_count=max_usage_count,
current_usage_count=0,
)
return _make_tool
@@ -360,4 +536,10 @@ def tool(
return _make_with_name(args[0].__name__)(args[0])
if len(args) == 1 and isinstance(args[0], str):
return _make_with_name(args[0])
if len(args) == 0:
def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]:
return _make_with_name(f.__name__)(f)
return decorator
raise ValueError("Invalid arguments")

View File

@@ -160,6 +160,251 @@ class ToolUsage:
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
async def ause(
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
) -> str:
"""Execute a tool asynchronously.
Args:
calling: The tool calling information.
tool_string: The raw tool string from the agent.
Returns:
The result of the tool execution as a string.
"""
if isinstance(calling, ToolUsageError):
error = calling.message
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
if self.task:
self.task.increment_tools_errors()
return error
try:
tool = self._select_tool(calling.tool_name)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
if (
isinstance(tool, CrewStructuredTool)
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
):
try:
return await self._ause(
tool_string=tool_string, tool=tool, calling=calling
)
except Exception as e:
error = getattr(e, "message", str(e))
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(content=f"\n\n{error}\n", color="red")
return error
return (
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
)
async def _ause(
self,
tool_string: str,
tool: CrewStructuredTool,
calling: ToolCalling | InstructorToolCalling,
) -> str:
"""Internal async tool execution implementation.
Args:
tool_string: The raw tool string from the agent.
tool: The tool to execute.
calling: The tool calling information.
Returns:
The result of the tool execution as a string.
"""
if self._check_tool_repeated_usage(calling=calling):
try:
result = self._i18n.errors("task_repeated_usage").format(
tool_names=self.tools_names
)
self._telemetry.tool_repeated_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if self.agent:
event_data = {
"agent_key": self.agent.key,
"agent_role": self.agent.role,
"tool_name": self.action.tool,
"tool_args": self.action.tool_input,
"tool_class": self.action.tool,
"agent": self.agent,
}
if self.agent.fingerprint: # type: ignore
event_data.update(self.agent.fingerprint) # type: ignore
if self.task:
event_data["task_name"] = self.task.name or self.task.description
event_data["task_id"] = str(self.task.id)
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
started_at = time.time()
from_cache = False
result = None # type: ignore
if self.tools_handler and self.tools_handler.cache:
input_str = ""
if calling.arguments:
if isinstance(calling.arguments, dict):
input_str = json.dumps(calling.arguments)
else:
input_str = str(calling.arguments)
result = self.tools_handler.cache.read(
tool=calling.tool_name, input=input_str
) # type: ignore
from_cache = result is not None
available_tool = next(
(
available_tool
for available_tool in self.tools
if available_tool.name == tool.name
),
None,
)
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
if usage_limit_error:
try:
result = usage_limit_error
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
return self._format_result(result=result)
except Exception:
if self.task:
self.task.increment_tools_errors()
if result is None:
try:
if calling.tool_name in [
"Delegate work to coworker",
"Ask question to coworker",
]:
coworker = (
calling.arguments.get("coworker") if calling.arguments else None
)
if self.task:
self.task.increment_delegations(coworker)
if calling.arguments:
try:
acceptable_args = tool.args_schema.model_json_schema()[
"properties"
].keys()
arguments = {
k: v
for k, v in calling.arguments.items()
if k in acceptable_args
}
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
except Exception:
arguments = calling.arguments
arguments = self._add_fingerprint_metadata(arguments)
result = await tool.ainvoke(input=arguments)
else:
arguments = self._add_fingerprint_metadata({})
result = await tool.ainvoke(input=arguments)
except Exception as e:
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
self._run_attempts += 1
if self._run_attempts > self._max_parsing_attempts:
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
error_message = self._i18n.errors("tool_usage_exception").format(
error=e, tool=tool.name, tool_inputs=tool.description
)
error = ToolUsageError(
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
).message
if self.task:
self.task.increment_tools_errors()
if self.agent and self.agent.verbose:
self._printer.print(
content=f"\n\n{error_message}\n", color="red"
)
return error
if self.task:
self.task.increment_tools_errors()
return await self.ause(calling=calling, tool_string=tool_string)
if self.tools_handler:
should_cache = True
if (
hasattr(available_tool, "cache_function")
and available_tool.cache_function
):
should_cache = available_tool.cache_function(
calling.arguments, result
)
self.tools_handler.on_tool_use(
calling=calling, output=result, should_cache=should_cache
)
self._telemetry.tool_usage(
llm=self.function_calling_llm,
tool_name=tool.name,
attempts=self._run_attempts,
)
result = self._format_result(result=result)
data = {
"result": result,
"tool_name": tool.name,
"tool_args": calling.arguments,
}
self.on_tool_use_finished(
tool=tool,
tool_calling=calling,
from_cache=from_cache,
started_at=started_at,
result=result,
)
if (
hasattr(available_tool, "result_as_answer")
and available_tool.result_as_answer # type: ignore
):
result_as_answer = available_tool.result_as_answer # type: ignore
data["result_as_answer"] = result_as_answer # type: ignore
if self.agent and hasattr(self.agent, "tools_results"):
self.agent.tools_results.append(data)
if available_tool and hasattr(available_tool, "current_usage_count"):
available_tool.current_usage_count += 1
if (
hasattr(available_tool, "max_usage_count")
and available_tool.max_usage_count is not None
):
self._printer.print(
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
color="blue",
)
return result
def _use(
self,
tool_string: str,

View File

@@ -237,22 +237,22 @@ def get_llm_response(
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
executor_context: CrewAgentExecutor | LiteAgent | None = None,
) -> str:
"""Call the LLM and return the response, handling any invalid responses.
Args:
llm: The LLM instance to call
messages: The messages to send to the LLM
callbacks: List of callbacks for the LLM call
printer: Printer instance for output
from_task: Optional task context for the LLM call
from_agent: Optional agent context for the LLM call
response_model: Optional Pydantic model for structured outputs
executor_context: Optional executor context for hook invocation
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
@@ -284,6 +284,60 @@ def get_llm_response(
return _setup_after_llm_call_hooks(executor_context, answer, printer)
async def aget_llm_response(
llm: LLM | BaseLLM,
messages: list[LLMMessage],
callbacks: list[TokenCalcHandler],
printer: Printer,
from_task: Task | None = None,
from_agent: Agent | LiteAgent | None = None,
response_model: type[BaseModel] | None = None,
executor_context: CrewAgentExecutor | None = None,
) -> str:
"""Call the LLM asynchronously and return the response.
Args:
llm: The LLM instance to call.
messages: The messages to send to the LLM.
callbacks: List of callbacks for the LLM call.
printer: Printer instance for output.
from_task: Optional task context for the LLM call.
from_agent: Optional agent context for the LLM call.
response_model: Optional Pydantic model for structured outputs.
executor_context: Optional executor context for hook invocation.
Returns:
The response from the LLM as a string.
Raises:
Exception: If an error occurs.
ValueError: If the response is None or empty.
"""
if executor_context is not None:
if not _setup_before_llm_call_hooks(executor_context, printer):
raise ValueError("LLM call blocked by before_llm_call hook")
messages = executor_context.messages
try:
answer = await llm.acall(
messages,
callbacks=callbacks,
from_task=from_task,
from_agent=from_agent, # type: ignore[arg-type]
response_model=response_model,
)
except Exception as e:
raise e
if not answer:
printer.print(
content="Received None or empty response from LLM call.",
color="red",
)
raise ValueError("Invalid response from LLM call - None or empty.")
return _setup_after_llm_call_hooks(executor_context, answer, printer)
def process_llm_response(
answer: str, use_stop_words: bool
) -> AgentAction | AgentFinish:
@@ -673,7 +727,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
def _setup_before_llm_call_hooks(
executor_context: CrewAgentExecutor | None, printer: Printer
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
) -> bool:
"""Setup and invoke before_llm_call hooks for the executor context.
@@ -723,7 +777,7 @@ def _setup_before_llm_call_hooks(
def _setup_after_llm_call_hooks(
executor_context: CrewAgentExecutor | None,
executor_context: CrewAgentExecutor | LiteAgent | None,
answer: str,
printer: Printer,
) -> str:

View File

@@ -26,6 +26,138 @@ if TYPE_CHECKING:
from crewai.task import Task
async def aexecute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
i18n: I18N,
agent_key: str | None = None,
agent_role: str | None = None,
tools_handler: ToolsHandler | None = None,
task: Task | None = None,
agent: Agent | BaseAgent | None = None,
function_calling_llm: BaseLLM | LLM | None = None,
fingerprint_context: dict[str, str] | None = None,
crew: Crew | None = None,
) -> ToolResult:
"""Execute a tool asynchronously and check if the result should be a final answer.
This is the async version of execute_tool_and_check_finality. It integrates tool
hooks for before and after tool execution, allowing programmatic interception
and modification of tool calls.
Args:
agent_action: The action containing the tool to execute.
tools: List of available tools.
i18n: Internationalization settings.
agent_key: Optional key for event emission.
agent_role: Optional role for event emission.
tools_handler: Optional tools handler for tool execution.
task: Optional task for tool execution.
agent: Optional agent instance for tool execution.
function_calling_llm: Optional LLM for function calling.
fingerprint_context: Optional context for fingerprinting.
crew: Optional crew instance for hook context.
Returns:
ToolResult containing the execution result and whether it should be
treated as a final answer.
"""
logger = Logger(verbose=crew.verbose if crew else False)
tool_name_to_tool_map = {tool.name: tool for tool in tools}
if agent_key and agent_role and agent:
fingerprint_context = fingerprint_context or {}
if agent:
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
if isinstance(fingerprint_context, dict):
try:
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
agent.set_fingerprint(fingerprint=fingerprint_obj)
except Exception as e:
raise ValueError(f"Failed to set fingerprint: {e}") from e
tool_usage = ToolUsage(
tools_handler=tools_handler,
tools=tools,
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
task=task,
agent=agent,
action=agent_action,
)
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
if isinstance(tool_calling, ToolUsageError):
return ToolResult(tool_calling.message, False)
if tool_calling.tool_name.casefold().strip() in [
name.casefold().strip() for name in tool_name_to_tool_map
] or tool_calling.tool_name.casefold().replace("_", " ") in [
name.casefold().strip() for name in tool_name_to_tool_map
]:
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
if not tool:
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([t.name.casefold() for t in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
tool_input = tool_calling.arguments if tool_calling.arguments else {}
hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
)
before_hooks = get_before_tool_call_hooks()
try:
for hook in before_hooks:
result = hook(hook_context)
if result is False:
blocked_message = (
f"Tool execution blocked by hook. "
f"Tool: {tool_calling.tool_name}"
)
return ToolResult(blocked_message, False)
except Exception as e:
logger.log("error", f"Error in before_tool_call hook: {e}")
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
after_hook_context = ToolCallHookContext(
tool_name=tool_calling.tool_name,
tool_input=tool_input,
tool=tool,
agent=agent,
task=task,
crew=crew,
tool_result=tool_result,
)
after_hooks = get_after_tool_call_hooks()
modified_result: str = tool_result
try:
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result
except Exception as e:
logger.log("error", f"Error in after_tool_call hook: {e}")
return ToolResult(modified_result, tool.result_as_answer)
tool_result = i18n.errors("wrong_tool_name").format(
tool=tool_calling.tool_name,
tools=", ".join([tool.name.casefold() for tool in tools]),
)
return ToolResult(result=tool_result, result_as_answer=False)
def execute_tool_and_check_finality(
agent_action: AgentAction,
tools: list[CrewStructuredTool],
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
# Execute after_tool_call hooks
after_hooks = get_after_tool_call_hooks()
modified_result = tool_result
modified_result: str = tool_result
try:
for hook in after_hooks:
hook_result = hook(after_hook_context)
for after_hook in after_hooks:
hook_result = after_hook(after_hook_context)
if hook_result is not None:
modified_result = hook_result
after_hook_context.tool_result = modified_result