mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-05-01 23:32:39 +00:00
Merge branch 'main' into gl/fix/task-planning-order
This commit is contained in:
@@ -40,7 +40,7 @@ def _suppress_pydantic_deprecation_warnings() -> None:
|
||||
|
||||
_suppress_pydantic_deprecation_warnings()
|
||||
|
||||
__version__ = "1.6.1"
|
||||
__version__ = "1.7.0"
|
||||
_telemetry_submitted = False
|
||||
|
||||
|
||||
|
||||
4
lib/crewai/src/crewai/a2a/extensions/__init__.py
Normal file
4
lib/crewai/src/crewai/a2a/extensions/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
"""A2A Protocol Extensions for CrewAI.
|
||||
|
||||
This module contains extensions to the A2A (Agent-to-Agent) protocol.
|
||||
"""
|
||||
193
lib/crewai/src/crewai/a2a/extensions/base.py
Normal file
193
lib/crewai/src/crewai/a2a/extensions/base.py
Normal file
@@ -0,0 +1,193 @@
|
||||
"""Base extension interface for A2A wrapper integrations.
|
||||
|
||||
This module defines the protocol for extending A2A wrapper functionality
|
||||
with custom logic for conversation processing, prompt augmentation, and
|
||||
agent response handling.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
from typing import TYPE_CHECKING, Any, Protocol
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from a2a.types import Message
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
|
||||
|
||||
class ConversationState(Protocol):
|
||||
"""Protocol for extension-specific conversation state.
|
||||
|
||||
Extensions can define their own state classes that implement this protocol
|
||||
to track conversation-specific data extracted from message history.
|
||||
"""
|
||||
|
||||
def is_ready(self) -> bool:
|
||||
"""Check if the state indicates readiness for some action.
|
||||
|
||||
Returns:
|
||||
True if the state is ready, False otherwise.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class A2AExtension(Protocol):
|
||||
"""Protocol for A2A wrapper extensions.
|
||||
|
||||
Extensions can implement this protocol to inject custom logic into
|
||||
the A2A conversation flow at various integration points.
|
||||
"""
|
||||
|
||||
def inject_tools(self, agent: Agent) -> None:
|
||||
"""Inject extension-specific tools into the agent.
|
||||
|
||||
Called when an agent is wrapped with A2A capabilities. Extensions
|
||||
can add tools that enable extension-specific functionality.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract_state_from_history(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> ConversationState | None:
|
||||
"""Extract extension-specific state from conversation history.
|
||||
|
||||
Called during prompt augmentation to allow extensions to analyze
|
||||
the conversation history and extract relevant state information.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Extension-specific conversation state, or None if no relevant state.
|
||||
"""
|
||||
...
|
||||
|
||||
def augment_prompt(
|
||||
self,
|
||||
base_prompt: str,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> str:
|
||||
"""Augment the task prompt with extension-specific instructions.
|
||||
|
||||
Called during prompt augmentation to allow extensions to add
|
||||
custom instructions based on conversation state.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The augmented prompt with extension-specific instructions.
|
||||
"""
|
||||
...
|
||||
|
||||
def process_response(
|
||||
self,
|
||||
agent_response: Any,
|
||||
conversation_state: ConversationState | None,
|
||||
) -> Any:
|
||||
"""Process and potentially modify the agent response.
|
||||
|
||||
Called after parsing the agent's response, allowing extensions to
|
||||
enhance or modify the response based on conversation state.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
conversation_state: Extension-specific state from extract_state_from_history.
|
||||
|
||||
Returns:
|
||||
The processed agent response (may be modified or original).
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class ExtensionRegistry:
|
||||
"""Registry for managing A2A extensions.
|
||||
|
||||
Maintains a collection of extensions and provides methods to invoke
|
||||
their hooks at various integration points.
|
||||
"""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the extension registry."""
|
||||
self._extensions: list[A2AExtension] = []
|
||||
|
||||
def register(self, extension: A2AExtension) -> None:
|
||||
"""Register an extension.
|
||||
|
||||
Args:
|
||||
extension: The extension to register.
|
||||
"""
|
||||
self._extensions.append(extension)
|
||||
|
||||
def inject_all_tools(self, agent: Agent) -> None:
|
||||
"""Inject tools from all registered extensions.
|
||||
|
||||
Args:
|
||||
agent: The agent instance to inject tools into.
|
||||
"""
|
||||
for extension in self._extensions:
|
||||
extension.inject_tools(agent)
|
||||
|
||||
def extract_all_states(
|
||||
self, conversation_history: Sequence[Message]
|
||||
) -> dict[type[A2AExtension], ConversationState]:
|
||||
"""Extract conversation states from all registered extensions.
|
||||
|
||||
Args:
|
||||
conversation_history: The sequence of A2A messages exchanged.
|
||||
|
||||
Returns:
|
||||
Mapping of extension types to their conversation states.
|
||||
"""
|
||||
states: dict[type[A2AExtension], ConversationState] = {}
|
||||
for extension in self._extensions:
|
||||
state = extension.extract_state_from_history(conversation_history)
|
||||
if state is not None:
|
||||
states[type(extension)] = state
|
||||
return states
|
||||
|
||||
def augment_prompt_with_all(
|
||||
self,
|
||||
base_prompt: str,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> str:
|
||||
"""Augment prompt with instructions from all registered extensions.
|
||||
|
||||
Args:
|
||||
base_prompt: The base prompt to augment.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The fully augmented prompt.
|
||||
"""
|
||||
augmented = base_prompt
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
augmented = extension.augment_prompt(augmented, state)
|
||||
return augmented
|
||||
|
||||
def process_response_with_all(
|
||||
self,
|
||||
agent_response: Any,
|
||||
extension_states: dict[type[A2AExtension], ConversationState],
|
||||
) -> Any:
|
||||
"""Process response through all registered extensions.
|
||||
|
||||
Args:
|
||||
agent_response: The parsed agent response.
|
||||
extension_states: Mapping of extension types to conversation states.
|
||||
|
||||
Returns:
|
||||
The processed agent response.
|
||||
"""
|
||||
processed = agent_response
|
||||
for extension in self._extensions:
|
||||
state = extension_states.get(type(extension))
|
||||
processed = extension.process_response(processed, state)
|
||||
return processed
|
||||
34
lib/crewai/src/crewai/a2a/extensions/registry.py
Normal file
34
lib/crewai/src/crewai/a2a/extensions/registry.py
Normal file
@@ -0,0 +1,34 @@
|
||||
"""Extension registry factory for A2A configurations.
|
||||
|
||||
This module provides utilities for creating extension registries from A2A configurations.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.a2a.config import A2AConfig
|
||||
|
||||
|
||||
def create_extension_registry_from_config(
|
||||
a2a_config: list[A2AConfig] | A2AConfig,
|
||||
) -> ExtensionRegistry:
|
||||
"""Create an extension registry from A2A configuration.
|
||||
|
||||
Args:
|
||||
a2a_config: A2A configuration (single or list)
|
||||
|
||||
Returns:
|
||||
Configured extension registry with all applicable extensions
|
||||
"""
|
||||
registry = ExtensionRegistry()
|
||||
configs = a2a_config if isinstance(a2a_config, list) else [a2a_config]
|
||||
|
||||
for _ in configs:
|
||||
pass
|
||||
|
||||
return registry
|
||||
@@ -23,6 +23,8 @@ from a2a.types import (
|
||||
TextPart,
|
||||
TransportProtocol,
|
||||
)
|
||||
from aiocache import cached # type: ignore[import-untyped]
|
||||
from aiocache.serializers import PickleSerializer # type: ignore[import-untyped]
|
||||
import httpx
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
@@ -65,7 +67,7 @@ def _fetch_agent_card_cached(
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth_hash: Hash of the auth object
|
||||
timeout: Request timeout
|
||||
_ttl_hash: Time-based hash for cache invalidation (unused in body)
|
||||
_ttl_hash: Time-based hash for cache invalidation
|
||||
|
||||
Returns:
|
||||
Cached AgentCard
|
||||
@@ -106,7 +108,18 @@ def fetch_agent_card(
|
||||
A2AClientHTTPError: If authentication fails
|
||||
"""
|
||||
if use_cache:
|
||||
auth_hash = hash((type(auth).__name__, id(auth))) if auth else 0
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
ttl_hash = int(time.time() // cache_ttl)
|
||||
return _fetch_agent_card_cached(endpoint, auth_hash, timeout, ttl_hash)
|
||||
@@ -121,6 +134,26 @@ def fetch_agent_card(
|
||||
loop.close()
|
||||
|
||||
|
||||
@cached(ttl=300, serializer=PickleSerializer()) # type: ignore[untyped-decorator]
|
||||
async def _fetch_agent_card_async_cached(
|
||||
endpoint: str,
|
||||
auth_hash: int,
|
||||
timeout: int,
|
||||
) -> AgentCard:
|
||||
"""Cached async implementation of AgentCard fetching.
|
||||
|
||||
Args:
|
||||
endpoint: A2A agent endpoint URL
|
||||
auth_hash: Hash of the auth object
|
||||
timeout: Request timeout in seconds
|
||||
|
||||
Returns:
|
||||
Cached AgentCard object
|
||||
"""
|
||||
auth = _auth_store.get(auth_hash)
|
||||
return await _fetch_agent_card_async(endpoint=endpoint, auth=auth, timeout=timeout)
|
||||
|
||||
|
||||
async def _fetch_agent_card_async(
|
||||
endpoint: str,
|
||||
auth: AuthScheme | None,
|
||||
@@ -339,7 +372,22 @@ async def _execute_a2a_delegation_async(
|
||||
Returns:
|
||||
Dictionary with status, result/error, and new history
|
||||
"""
|
||||
agent_card = await _fetch_agent_card_async(endpoint, auth, timeout)
|
||||
if auth:
|
||||
auth_data = auth.model_dump_json(
|
||||
exclude={
|
||||
"_access_token",
|
||||
"_token_expires_at",
|
||||
"_refresh_token",
|
||||
"_authorization_callback",
|
||||
}
|
||||
)
|
||||
auth_hash = hash((type(auth).__name__, auth_data))
|
||||
else:
|
||||
auth_hash = 0
|
||||
_auth_store[auth_hash] = auth
|
||||
agent_card = await _fetch_agent_card_async_cached(
|
||||
endpoint=endpoint, auth_hash=auth_hash, timeout=timeout
|
||||
)
|
||||
|
||||
validate_auth_against_agent_card(agent_card, auth)
|
||||
|
||||
@@ -556,6 +604,34 @@ async def _execute_a2a_delegation_async(
|
||||
}
|
||||
break
|
||||
except Exception as e:
|
||||
if isinstance(e, A2AClientHTTPError):
|
||||
error_msg = f"HTTP Error {e.status_code}: {e!s}"
|
||||
|
||||
error_message = Message(
|
||||
role=Role.agent,
|
||||
message_id=str(uuid.uuid4()),
|
||||
parts=[Part(root=TextPart(text=error_msg))],
|
||||
context_id=context_id,
|
||||
task_id=task_id,
|
||||
)
|
||||
new_messages.append(error_message)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AResponseReceivedEvent(
|
||||
response=error_msg,
|
||||
turn_number=turn_number,
|
||||
is_multiturn=is_multiturn,
|
||||
status="failed",
|
||||
agent_role=agent_role,
|
||||
),
|
||||
)
|
||||
return {
|
||||
"status": "failed",
|
||||
"error": error_msg,
|
||||
"history": new_messages,
|
||||
}
|
||||
|
||||
current_exception: Exception | BaseException | None = e
|
||||
while current_exception:
|
||||
if hasattr(current_exception, "response"):
|
||||
@@ -752,4 +828,5 @@ def get_a2a_agents_and_response_model(
|
||||
Tuple of A2A agent IDs and response model
|
||||
"""
|
||||
a2a_agents, agent_ids = extract_a2a_agent_ids_from_config(a2a_config=a2a_config)
|
||||
|
||||
return a2a_agents, create_agent_response_model(agent_ids)
|
||||
|
||||
@@ -15,6 +15,7 @@ from a2a.types import Role
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.a2a.extensions.base import ExtensionRegistry
|
||||
from crewai.a2a.templates import (
|
||||
AVAILABLE_AGENTS_TEMPLATE,
|
||||
CONVERSATION_TURN_INFO_TEMPLATE,
|
||||
@@ -42,7 +43,9 @@ if TYPE_CHECKING:
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
def wrap_agent_with_a2a_instance(
|
||||
agent: Agent, extension_registry: ExtensionRegistry | None = None
|
||||
) -> None:
|
||||
"""Wrap an agent instance's execute_task method with A2A support.
|
||||
|
||||
This function modifies the agent instance by wrapping its execute_task
|
||||
@@ -51,7 +54,13 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
|
||||
Args:
|
||||
agent: The agent instance to wrap
|
||||
extension_registry: Optional registry of A2A extensions for injecting tools and custom logic
|
||||
"""
|
||||
if extension_registry is None:
|
||||
extension_registry = ExtensionRegistry()
|
||||
|
||||
extension_registry.inject_all_tools(agent)
|
||||
|
||||
original_execute_task = agent.execute_task.__func__ # type: ignore[attr-defined]
|
||||
|
||||
@wraps(original_execute_task)
|
||||
@@ -85,6 +94,7 @@ def wrap_agent_with_a2a_instance(agent: Agent) -> None:
|
||||
agent_response_model=agent_response_model,
|
||||
context=context,
|
||||
tools=tools,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
|
||||
object.__setattr__(agent, "execute_task", MethodType(execute_task_with_a2a, agent))
|
||||
@@ -154,6 +164,7 @@ def _execute_task_with_a2a(
|
||||
agent_response_model: type[BaseModel],
|
||||
context: str | None,
|
||||
tools: list[BaseTool] | None,
|
||||
extension_registry: ExtensionRegistry,
|
||||
) -> str:
|
||||
"""Wrap execute_task with A2A delegation logic.
|
||||
|
||||
@@ -165,6 +176,7 @@ def _execute_task_with_a2a(
|
||||
context: Optional context for task execution
|
||||
tools: Optional tools available to the agent
|
||||
agent_response_model: Optional agent response model
|
||||
extension_registry: Registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Task execution result (either from LLM or A2A agent)
|
||||
@@ -190,11 +202,12 @@ def _execute_task_with_a2a(
|
||||
finally:
|
||||
task.description = original_description
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
task.description, _ = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_description,
|
||||
agent_cards=agent_cards,
|
||||
failed_agents=failed_agents,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
task.response_model = agent_response_model
|
||||
|
||||
@@ -204,6 +217,11 @@ def _execute_task_with_a2a(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
|
||||
if extension_registry and isinstance(agent_response, BaseModel):
|
||||
agent_response = extension_registry.process_response_with_all(
|
||||
agent_response, {}
|
||||
)
|
||||
|
||||
if isinstance(agent_response, BaseModel) and isinstance(
|
||||
agent_response, AgentResponseProtocol
|
||||
):
|
||||
@@ -217,6 +235,7 @@ def _execute_task_with_a2a(
|
||||
tools=tools,
|
||||
agent_cards=agent_cards,
|
||||
original_task_description=original_description,
|
||||
extension_registry=extension_registry,
|
||||
)
|
||||
return str(agent_response.message)
|
||||
|
||||
@@ -235,7 +254,8 @@ def _augment_prompt_with_a2a(
|
||||
turn_num: int = 0,
|
||||
max_turns: int | None = None,
|
||||
failed_agents: dict[str, str] | None = None,
|
||||
) -> str:
|
||||
extension_registry: ExtensionRegistry | None = None,
|
||||
) -> tuple[str, bool]:
|
||||
"""Add A2A delegation instructions to prompt.
|
||||
|
||||
Args:
|
||||
@@ -246,13 +266,14 @@ def _augment_prompt_with_a2a(
|
||||
turn_num: Current turn number (0-indexed)
|
||||
max_turns: Maximum allowed turns (from config)
|
||||
failed_agents: Dictionary mapping failed agent endpoints to error messages
|
||||
extension_registry: Optional registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Augmented task description with A2A instructions
|
||||
Tuple of (augmented prompt, disable_structured_output flag)
|
||||
"""
|
||||
|
||||
if not agent_cards:
|
||||
return task_description
|
||||
return task_description, False
|
||||
|
||||
agents_text = ""
|
||||
|
||||
@@ -270,6 +291,7 @@ def _augment_prompt_with_a2a(
|
||||
agents_text = AVAILABLE_AGENTS_TEMPLATE.substitute(available_a2a_agents=agents_text)
|
||||
|
||||
history_text = ""
|
||||
|
||||
if conversation_history:
|
||||
for msg in conversation_history:
|
||||
history_text += f"\n{msg.model_dump_json(indent=2, exclude_none=True, exclude={'message_id'})}\n"
|
||||
@@ -277,6 +299,15 @@ def _augment_prompt_with_a2a(
|
||||
history_text = PREVIOUS_A2A_CONVERSATION_TEMPLATE.substitute(
|
||||
previous_a2a_conversation=history_text
|
||||
)
|
||||
|
||||
extension_states = {}
|
||||
disable_structured_output = False
|
||||
if extension_registry and conversation_history:
|
||||
extension_states = extension_registry.extract_all_states(conversation_history)
|
||||
for state in extension_states.values():
|
||||
if state.is_ready():
|
||||
disable_structured_output = True
|
||||
break
|
||||
turn_info = ""
|
||||
|
||||
if max_turns is not None and conversation_history:
|
||||
@@ -296,16 +327,22 @@ def _augment_prompt_with_a2a(
|
||||
warning=warning,
|
||||
)
|
||||
|
||||
return f"""{task_description}
|
||||
augmented_prompt = f"""{task_description}
|
||||
|
||||
IMPORTANT: You have the ability to delegate this task to remote A2A agents.
|
||||
|
||||
{agents_text}
|
||||
{history_text}{turn_info}
|
||||
|
||||
|
||||
"""
|
||||
|
||||
if extension_registry:
|
||||
augmented_prompt = extension_registry.augment_prompt_with_all(
|
||||
augmented_prompt, extension_states
|
||||
)
|
||||
|
||||
return augmented_prompt, disable_structured_output
|
||||
|
||||
|
||||
def _parse_agent_response(
|
||||
raw_result: str | dict[str, Any], agent_response_model: type[BaseModel]
|
||||
@@ -373,7 +410,7 @@ def _handle_agent_response_and_continue(
|
||||
if "agent_card" in a2a_result and agent_id not in agent_cards_dict:
|
||||
agent_cards_dict[agent_id] = a2a_result["agent_card"]
|
||||
|
||||
task.description = _augment_prompt_with_a2a(
|
||||
task.description, disable_structured_output = _augment_prompt_with_a2a(
|
||||
a2a_agents=a2a_agents,
|
||||
task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
@@ -382,7 +419,38 @@ def _handle_agent_response_and_continue(
|
||||
agent_cards=agent_cards_dict,
|
||||
)
|
||||
|
||||
original_response_model = task.response_model
|
||||
if disable_structured_output:
|
||||
task.response_model = None
|
||||
|
||||
raw_result = original_fn(self, task, context, tools)
|
||||
|
||||
if disable_structured_output:
|
||||
task.response_model = original_response_model
|
||||
|
||||
if disable_structured_output:
|
||||
final_turn_number = turn_num + 1
|
||||
result_text = str(raw_result)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AMessageSentEvent(
|
||||
message=result_text,
|
||||
turn_number=final_turn_number,
|
||||
is_multiturn=True,
|
||||
agent_role=self.role,
|
||||
),
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
status="completed",
|
||||
final_result=result_text,
|
||||
error=None,
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return result_text, None
|
||||
|
||||
llm_response = _parse_agent_response(
|
||||
raw_result=raw_result, agent_response_model=agent_response_model
|
||||
)
|
||||
@@ -425,6 +493,7 @@ def _delegate_to_a2a(
|
||||
tools: list[BaseTool] | None,
|
||||
agent_cards: dict[str, AgentCard] | None = None,
|
||||
original_task_description: str | None = None,
|
||||
extension_registry: ExtensionRegistry | None = None,
|
||||
) -> str:
|
||||
"""Delegate to A2A agent with multi-turn conversation support.
|
||||
|
||||
@@ -437,6 +506,7 @@ def _delegate_to_a2a(
|
||||
tools: Optional tools available to the agent
|
||||
agent_cards: Pre-fetched agent cards from _execute_task_with_a2a
|
||||
original_task_description: The original task description before A2A augmentation
|
||||
extension_registry: Optional registry of A2A extensions
|
||||
|
||||
Returns:
|
||||
Result from A2A agent
|
||||
@@ -447,9 +517,13 @@ def _delegate_to_a2a(
|
||||
a2a_agents, agent_response_model = get_a2a_agents_and_response_model(self.a2a)
|
||||
agent_ids = tuple(config.endpoint for config in a2a_agents)
|
||||
current_request = str(agent_response.message)
|
||||
agent_id = agent_response.a2a_ids[0]
|
||||
|
||||
if agent_id not in agent_ids:
|
||||
if hasattr(agent_response, "a2a_ids") and agent_response.a2a_ids:
|
||||
agent_id = agent_response.a2a_ids[0]
|
||||
else:
|
||||
agent_id = agent_ids[0] if agent_ids else ""
|
||||
|
||||
if agent_id and agent_id not in agent_ids:
|
||||
raise ValueError(
|
||||
f"Unknown A2A agent ID(s): {agent_response.a2a_ids} not in {agent_ids}"
|
||||
)
|
||||
@@ -458,10 +532,11 @@ def _delegate_to_a2a(
|
||||
task_config = task.config or {}
|
||||
context_id = task_config.get("context_id")
|
||||
task_id_config = task_config.get("task_id")
|
||||
reference_task_ids = task_config.get("reference_task_ids")
|
||||
metadata = task_config.get("metadata")
|
||||
extensions = task_config.get("extensions")
|
||||
|
||||
reference_task_ids = task_config.get("reference_task_ids", [])
|
||||
|
||||
if original_task_description is None:
|
||||
original_task_description = task.description
|
||||
|
||||
@@ -497,11 +572,27 @@ def _delegate_to_a2a(
|
||||
|
||||
conversation_history = a2a_result.get("history", [])
|
||||
|
||||
if conversation_history:
|
||||
latest_message = conversation_history[-1]
|
||||
if latest_message.task_id is not None:
|
||||
task_id_config = latest_message.task_id
|
||||
if latest_message.context_id is not None:
|
||||
context_id = latest_message.context_id
|
||||
|
||||
if a2a_result["status"] in ["completed", "input_required"]:
|
||||
if (
|
||||
a2a_result["status"] == "completed"
|
||||
and agent_config.trust_remote_completion_status
|
||||
):
|
||||
if (
|
||||
task_id_config is not None
|
||||
and task_id_config not in reference_task_ids
|
||||
):
|
||||
reference_task_ids.append(task_id_config)
|
||||
if task.config is None:
|
||||
task.config = {}
|
||||
task.config["reference_task_ids"] = reference_task_ids
|
||||
|
||||
result_text = a2a_result.get("result", "")
|
||||
final_turn_number = turn_num + 1
|
||||
crewai_event_bus.emit(
|
||||
@@ -513,7 +604,7 @@ def _delegate_to_a2a(
|
||||
total_turns=final_turn_number,
|
||||
),
|
||||
)
|
||||
return result_text # type: ignore[no-any-return]
|
||||
return cast(str, result_text)
|
||||
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
@@ -541,6 +632,31 @@ def _delegate_to_a2a(
|
||||
continue
|
||||
|
||||
error_msg = a2a_result.get("error", "Unknown error")
|
||||
|
||||
final_result, next_request = _handle_agent_response_and_continue(
|
||||
self=self,
|
||||
a2a_result=a2a_result,
|
||||
agent_id=agent_id,
|
||||
agent_cards=agent_cards,
|
||||
a2a_agents=a2a_agents,
|
||||
original_task_description=original_task_description,
|
||||
conversation_history=conversation_history,
|
||||
turn_num=turn_num,
|
||||
max_turns=max_turns,
|
||||
task=task,
|
||||
original_fn=original_fn,
|
||||
context=context,
|
||||
tools=tools,
|
||||
agent_response_model=agent_response_model,
|
||||
)
|
||||
|
||||
if final_result is not None:
|
||||
return final_result
|
||||
|
||||
if next_request is not None:
|
||||
current_request = next_request
|
||||
continue
|
||||
|
||||
crewai_event_bus.emit(
|
||||
None,
|
||||
A2AConversationCompletedEvent(
|
||||
@@ -550,7 +666,7 @@ def _delegate_to_a2a(
|
||||
total_turns=turn_num + 1,
|
||||
),
|
||||
)
|
||||
raise Exception(f"A2A delegation failed: {error_msg}")
|
||||
return f"A2A delegation failed: {error_msg}"
|
||||
|
||||
if conversation_history:
|
||||
for msg in reversed(conversation_history):
|
||||
|
||||
@@ -2,7 +2,6 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Sequence
|
||||
import json
|
||||
import shutil
|
||||
import subprocess
|
||||
import time
|
||||
@@ -19,6 +18,19 @@ from pydantic import BaseModel, Field, InstanceOf, PrivateAttr, model_validator
|
||||
from typing_extensions import Self
|
||||
|
||||
from crewai.a2a.config import A2AConfig
|
||||
from crewai.agent.utils import (
|
||||
ahandle_knowledge_retrieval,
|
||||
apply_training_data,
|
||||
build_task_prompt_with_schema,
|
||||
format_task_with_context,
|
||||
get_knowledge_config,
|
||||
handle_knowledge_retrieval,
|
||||
handle_reasoning,
|
||||
prepare_tools,
|
||||
process_tool_results,
|
||||
save_last_messages,
|
||||
validate_max_execution_time,
|
||||
)
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
@@ -27,9 +39,6 @@ from crewai.events.types.knowledge_events import (
|
||||
KnowledgeQueryCompletedEvent,
|
||||
KnowledgeQueryFailedEvent,
|
||||
KnowledgeQueryStartedEvent,
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.events.types.memory_events import (
|
||||
MemoryRetrievalCompletedEvent,
|
||||
@@ -37,7 +46,6 @@ from crewai.events.types.memory_events import (
|
||||
)
|
||||
from crewai.knowledge.knowledge import Knowledge
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.mcp import (
|
||||
@@ -61,7 +69,7 @@ from crewai.utilities.agent_utils import (
|
||||
render_text_description_and_args,
|
||||
)
|
||||
from crewai.utilities.constants import TRAINED_AGENTS_DATA_FILE, TRAINING_DATA_FILE
|
||||
from crewai.utilities.converter import Converter, generate_model_description
|
||||
from crewai.utilities.converter import Converter
|
||||
from crewai.utilities.guardrail_types import GuardrailType
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.prompts import Prompts
|
||||
@@ -295,53 +303,15 @@ class Agent(BaseAgent):
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
if self.reasoning:
|
||||
try:
|
||||
from crewai.utilities.reasoning_handler import (
|
||||
AgentReasoning,
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=self)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
)
|
||||
|
||||
# Add the reasoning plan to the task description
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
self._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
handle_reasoning(self, task)
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
|
||||
# If the task requires output in JSON or Pydantic format,
|
||||
# append specific instructions to the task prompt to ensure
|
||||
# that the final answer does not include any code block markers
|
||||
# Skip this if task.response_model is set, as native structured outputs handle schema automatically
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
# Generate the schema based on the output format
|
||||
if task.output_json:
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
elif task.output_pydantic:
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + self.i18n.slice(
|
||||
"formatted_task_instructions"
|
||||
).format(output_format=schema)
|
||||
|
||||
if context:
|
||||
task_prompt = self.i18n.slice("task_with_context").format(
|
||||
task=task_prompt, context=context
|
||||
)
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||
|
||||
if self._is_any_available_memory():
|
||||
crewai_event_bus.emit(
|
||||
@@ -379,84 +349,20 @@ class Agent(BaseAgent):
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
knowledge_config = (
|
||||
self.knowledge_config.model_dump() if self.knowledge_config else {}
|
||||
|
||||
knowledge_config = get_knowledge_config(self)
|
||||
task_prompt = handle_knowledge_retrieval(
|
||||
self,
|
||||
task,
|
||||
task_prompt,
|
||||
knowledge_config,
|
||||
self.knowledge.query if self.knowledge else lambda *a, **k: None,
|
||||
self.crew.query_knowledge if self.crew else lambda *a, **k: None,
|
||||
)
|
||||
|
||||
if self.knowledge or (self.crew and self.crew.knowledge):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
try:
|
||||
self.knowledge_search_query = self._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if self.knowledge_search_query:
|
||||
# Quering agent specific knowledge
|
||||
if self.knowledge:
|
||||
agent_knowledge_snippets = self.knowledge.query(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
self.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if self.agent_knowledge_context:
|
||||
task_prompt += self.agent_knowledge_context
|
||||
prepare_tools(self, tools, task)
|
||||
task_prompt = apply_training_data(self, task_prompt)
|
||||
|
||||
# Quering crew specific knowledge
|
||||
knowledge_snippets = self.crew.query_knowledge(
|
||||
[self.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
self.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if self.crew_knowledge_context:
|
||||
task_prompt += self.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=self.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
retrieved_knowledge=(
|
||||
(self.agent_knowledge_context or "")
|
||||
+ (
|
||||
"\n"
|
||||
if self.agent_knowledge_context
|
||||
and self.crew_knowledge_context
|
||||
else ""
|
||||
)
|
||||
+ (self.crew_knowledge_context or "")
|
||||
),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=self.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=self,
|
||||
),
|
||||
)
|
||||
|
||||
tools = tools or self.tools or []
|
||||
self.create_agent_executor(tools=tools, task=task)
|
||||
|
||||
if self.crew and self.crew._train:
|
||||
task_prompt = self._training_handler(task_prompt=task_prompt)
|
||||
else:
|
||||
task_prompt = self._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
# Import agent events locally to avoid circular imports
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
@@ -474,15 +380,8 @@ class Agent(BaseAgent):
|
||||
),
|
||||
)
|
||||
|
||||
# Determine execution method based on timeout setting
|
||||
validate_max_execution_time(self.max_execution_time)
|
||||
if self.max_execution_time is not None:
|
||||
if (
|
||||
not isinstance(self.max_execution_time, int)
|
||||
or self.max_execution_time <= 0
|
||||
):
|
||||
raise ValueError(
|
||||
"Max Execution time must be a positive integer greater than zero"
|
||||
)
|
||||
result = self._execute_with_timeout(
|
||||
task_prompt, task, self.max_execution_time
|
||||
)
|
||||
@@ -490,7 +389,6 @@ class Agent(BaseAgent):
|
||||
result = self._execute_without_timeout(task_prompt, task)
|
||||
|
||||
except TimeoutError as e:
|
||||
# Propagate TimeoutError without retry
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -502,7 +400,6 @@ class Agent(BaseAgent):
|
||||
raise e
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
# Do not retry on litellm errors
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
@@ -528,23 +425,13 @@ class Agent(BaseAgent):
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
# If there was any tool in self.tools_results that had result_as_answer
|
||||
# set to True, return the results of the last tool that had
|
||||
# result_as_answer set to True
|
||||
for tool_result in self.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
result = process_tool_results(self, result)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
|
||||
self._last_messages = (
|
||||
self.agent_executor.messages.copy()
|
||||
if self.agent_executor and hasattr(self.agent_executor, "messages")
|
||||
else []
|
||||
)
|
||||
|
||||
save_last_messages(self)
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
return result
|
||||
@@ -604,6 +491,208 @@ class Agent(BaseAgent):
|
||||
}
|
||||
)["output"]
|
||||
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Task,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> Any:
|
||||
"""Execute a task with the agent asynchronously.
|
||||
|
||||
Args:
|
||||
task: Task to execute.
|
||||
context: Context to execute the task in.
|
||||
tools: Tools to use for the task.
|
||||
|
||||
Returns:
|
||||
Output of the agent.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds the maximum execution time.
|
||||
ValueError: If the max execution time is not a positive integer.
|
||||
RuntimeError: If the agent execution fails for other reasons.
|
||||
"""
|
||||
handle_reasoning(self, task)
|
||||
self._inject_date_to_task(task)
|
||||
|
||||
if self.tools_handler:
|
||||
self.tools_handler.last_used_tool = None
|
||||
|
||||
task_prompt = task.prompt()
|
||||
task_prompt = build_task_prompt_with_schema(task, task_prompt, self.i18n)
|
||||
task_prompt = format_task_with_context(task_prompt, context, self.i18n)
|
||||
|
||||
if self._is_any_available_memory():
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalStartedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
contextual_memory = ContextualMemory(
|
||||
self.crew._short_term_memory,
|
||||
self.crew._long_term_memory,
|
||||
self.crew._entity_memory,
|
||||
self.crew._external_memory,
|
||||
agent=self,
|
||||
task=task,
|
||||
)
|
||||
memory = await contextual_memory.abuild_context_for_task(
|
||||
task, context or ""
|
||||
)
|
||||
if memory.strip() != "":
|
||||
task_prompt += self.i18n.slice("memory").format(memory=memory)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryRetrievalCompletedEvent(
|
||||
task_id=str(task.id) if task else None,
|
||||
memory_content=memory,
|
||||
retrieval_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="agent",
|
||||
from_agent=self,
|
||||
from_task=task,
|
||||
),
|
||||
)
|
||||
|
||||
knowledge_config = get_knowledge_config(self)
|
||||
task_prompt = await ahandle_knowledge_retrieval(
|
||||
self, task, task_prompt, knowledge_config
|
||||
)
|
||||
|
||||
prepare_tools(self, tools, task)
|
||||
task_prompt = apply_training_data(self, task_prompt)
|
||||
|
||||
from crewai.events.types.agent_events import (
|
||||
AgentExecutionCompletedEvent,
|
||||
AgentExecutionErrorEvent,
|
||||
AgentExecutionStartedEvent,
|
||||
)
|
||||
|
||||
try:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionStartedEvent(
|
||||
agent=self,
|
||||
tools=self.tools,
|
||||
task_prompt=task_prompt,
|
||||
task=task,
|
||||
),
|
||||
)
|
||||
|
||||
validate_max_execution_time(self.max_execution_time)
|
||||
if self.max_execution_time is not None:
|
||||
result = await self._aexecute_with_timeout(
|
||||
task_prompt, task, self.max_execution_time
|
||||
)
|
||||
else:
|
||||
result = await self._aexecute_without_timeout(task_prompt, task)
|
||||
|
||||
except TimeoutError as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
self._times_executed += 1
|
||||
if self._times_executed > self.max_retry_limit:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionErrorEvent(
|
||||
agent=self,
|
||||
task=task,
|
||||
error=str(e),
|
||||
),
|
||||
)
|
||||
raise e
|
||||
result = await self.aexecute_task(task, context, tools)
|
||||
|
||||
if self.max_rpm and self._rpm_controller:
|
||||
self._rpm_controller.stop_rpm_counter()
|
||||
|
||||
result = process_tool_results(self, result)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=AgentExecutionCompletedEvent(agent=self, task=task, output=result),
|
||||
)
|
||||
|
||||
save_last_messages(self)
|
||||
self._cleanup_mcp_clients()
|
||||
|
||||
return result
|
||||
|
||||
async def _aexecute_with_timeout(
|
||||
self, task_prompt: str, task: Task, timeout: int
|
||||
) -> Any:
|
||||
"""Execute a task with a timeout asynchronously.
|
||||
|
||||
Args:
|
||||
task_prompt: The prompt to send to the agent.
|
||||
task: The task being executed.
|
||||
timeout: Maximum execution time in seconds.
|
||||
|
||||
Returns:
|
||||
The output of the agent.
|
||||
|
||||
Raises:
|
||||
TimeoutError: If execution exceeds the timeout.
|
||||
RuntimeError: If execution fails for other reasons.
|
||||
"""
|
||||
try:
|
||||
return await asyncio.wait_for(
|
||||
self._aexecute_without_timeout(task_prompt, task),
|
||||
timeout=timeout,
|
||||
)
|
||||
except asyncio.TimeoutError as e:
|
||||
raise TimeoutError(
|
||||
f"Task '{task.description}' execution timed out after {timeout} seconds. "
|
||||
"Consider increasing max_execution_time or optimizing the task."
|
||||
) from e
|
||||
|
||||
async def _aexecute_without_timeout(self, task_prompt: str, task: Task) -> Any:
|
||||
"""Execute a task without a timeout asynchronously.
|
||||
|
||||
Args:
|
||||
task_prompt: The prompt to send to the agent.
|
||||
task: The task being executed.
|
||||
|
||||
Returns:
|
||||
The output of the agent.
|
||||
"""
|
||||
if not self.agent_executor:
|
||||
raise RuntimeError("Agent executor is not initialized.")
|
||||
|
||||
result = await self.agent_executor.ainvoke(
|
||||
{
|
||||
"input": task_prompt,
|
||||
"tool_names": self.agent_executor.tools_names,
|
||||
"tools": self.agent_executor.tools_description,
|
||||
"ask_for_human_input": task.human_input,
|
||||
}
|
||||
)
|
||||
return result["output"]
|
||||
|
||||
def create_agent_executor(
|
||||
self, tools: list[BaseTool] | None = None, task: Task | None = None
|
||||
) -> None:
|
||||
@@ -633,7 +722,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
self.agent_executor = CrewAgentExecutor(
|
||||
llm=self.llm,
|
||||
llm=self.llm, # type: ignore[arg-type]
|
||||
task=task, # type: ignore[arg-type]
|
||||
agent=self,
|
||||
crew=self.crew,
|
||||
@@ -810,6 +899,7 @@ class Agent(BaseAgent):
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.tools.mcp_native_tool import MCPNativeTool
|
||||
|
||||
transport: StdioTransport | HTTPTransport | SSETransport
|
||||
if isinstance(mcp_config, MCPServerStdio):
|
||||
transport = StdioTransport(
|
||||
command=mcp_config.command,
|
||||
@@ -903,10 +993,10 @@ class Agent(BaseAgent):
|
||||
server_name=server_name,
|
||||
run_context=None,
|
||||
)
|
||||
if mcp_config.tool_filter(context, tool):
|
||||
if mcp_config.tool_filter(context, tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
except (TypeError, AttributeError):
|
||||
if mcp_config.tool_filter(tool):
|
||||
if mcp_config.tool_filter(tool): # type: ignore[call-arg, arg-type]
|
||||
filtered_tools.append(tool)
|
||||
else:
|
||||
# Not callable - include tool
|
||||
@@ -981,7 +1071,9 @@ class Agent(BaseAgent):
|
||||
path = parsed.path.replace("/", "_").strip("_")
|
||||
return f"{domain}_{path}" if path else domain
|
||||
|
||||
def _get_mcp_tool_schemas(self, server_params: dict) -> dict[str, dict]:
|
||||
def _get_mcp_tool_schemas(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Get tool schemas from MCP server for wrapper creation with caching."""
|
||||
server_url = server_params["url"]
|
||||
|
||||
@@ -995,7 +1087,7 @@ class Agent(BaseAgent):
|
||||
self._logger.log(
|
||||
"debug", f"Using cached MCP tool schemas for {server_url}"
|
||||
)
|
||||
return cached_data
|
||||
return cached_data # type: ignore[no-any-return]
|
||||
|
||||
try:
|
||||
schemas = asyncio.run(self._get_mcp_tool_schemas_async(server_params))
|
||||
@@ -1013,7 +1105,7 @@ class Agent(BaseAgent):
|
||||
|
||||
async def _get_mcp_tool_schemas_async(
|
||||
self, server_params: dict[str, Any]
|
||||
) -> dict[str, dict]:
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Async implementation of MCP tool schema retrieval with timeouts and retries."""
|
||||
server_url = server_params["url"]
|
||||
return await self._retry_mcp_discovery(
|
||||
@@ -1021,7 +1113,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
async def _retry_mcp_discovery(
|
||||
self, operation_func, server_url: str
|
||||
self, operation_func: Any, server_url: str
|
||||
) -> dict[str, dict[str, Any]]:
|
||||
"""Retry MCP discovery operation with exponential backoff, avoiding try-except in loop."""
|
||||
last_error = None
|
||||
@@ -1052,7 +1144,7 @@ class Agent(BaseAgent):
|
||||
|
||||
@staticmethod
|
||||
async def _attempt_mcp_discovery(
|
||||
operation_func, server_url: str
|
||||
operation_func: Any, server_url: str
|
||||
) -> tuple[dict[str, dict[str, Any]] | None, str, bool]:
|
||||
"""Attempt single MCP discovery operation and return (result, error_message, should_retry)."""
|
||||
try:
|
||||
@@ -1142,7 +1234,7 @@ class Agent(BaseAgent):
|
||||
properties = json_schema.get("properties", {})
|
||||
required_fields = json_schema.get("required", [])
|
||||
|
||||
field_definitions = {}
|
||||
field_definitions: dict[str, Any] = {}
|
||||
|
||||
for field_name, field_schema in properties.items():
|
||||
field_type = self._json_type_to_python(field_schema)
|
||||
@@ -1162,7 +1254,7 @@ class Agent(BaseAgent):
|
||||
)
|
||||
|
||||
model_name = f"{tool_name.replace('-', '_').replace(' ', '_')}Schema"
|
||||
return create_model(model_name, **field_definitions)
|
||||
return create_model(model_name, **field_definitions) # type: ignore[no-any-return]
|
||||
|
||||
def _json_type_to_python(self, field_schema: dict[str, Any]) -> type:
|
||||
"""Convert JSON Schema type to Python type.
|
||||
@@ -1177,7 +1269,7 @@ class Agent(BaseAgent):
|
||||
json_type = field_schema.get("type")
|
||||
|
||||
if "anyOf" in field_schema:
|
||||
types = []
|
||||
types: list[type] = []
|
||||
for option in field_schema["anyOf"]:
|
||||
if "const" in option:
|
||||
types.append(str)
|
||||
@@ -1185,13 +1277,13 @@ class Agent(BaseAgent):
|
||||
types.append(self._json_type_to_python(option))
|
||||
unique_types = list(set(types))
|
||||
if len(unique_types) > 1:
|
||||
result = unique_types[0]
|
||||
result: Any = unique_types[0]
|
||||
for t in unique_types[1:]:
|
||||
result = result | t
|
||||
return result
|
||||
return result # type: ignore[no-any-return]
|
||||
return unique_types[0]
|
||||
|
||||
type_mapping = {
|
||||
type_mapping: dict[str | None, type] = {
|
||||
"string": str,
|
||||
"number": float,
|
||||
"integer": int,
|
||||
@@ -1203,7 +1295,7 @@ class Agent(BaseAgent):
|
||||
return type_mapping.get(json_type, Any)
|
||||
|
||||
@staticmethod
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict]:
|
||||
def _fetch_amp_mcp_servers(mcp_name: str) -> list[dict[str, Any]]:
|
||||
"""Fetch MCP server configurations from CrewAI AOP API."""
|
||||
# TODO: Implement AMP API call to "integrations/mcps" endpoint
|
||||
# Should return list of server configs with URLs
|
||||
@@ -1438,11 +1530,11 @@ class Agent(BaseAgent):
|
||||
"""
|
||||
if self.apps:
|
||||
platform_tools = self.get_platform_tools(self.apps)
|
||||
if platform_tools:
|
||||
if platform_tools and self.tools is not None:
|
||||
self.tools.extend(platform_tools)
|
||||
if self.mcps:
|
||||
mcps = self.get_mcp_tools(self.mcps)
|
||||
if mcps:
|
||||
if mcps and self.tools is not None:
|
||||
self.tools.extend(mcps)
|
||||
|
||||
lite_agent = LiteAgent(
|
||||
|
||||
@@ -4,9 +4,8 @@ This metaclass enables extension capabilities for agents by detecting
|
||||
extension fields in class annotations and applying appropriate wrappers.
|
||||
"""
|
||||
|
||||
import warnings
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
import warnings
|
||||
|
||||
from pydantic import model_validator
|
||||
from pydantic._internal._model_construction import ModelMetaclass
|
||||
@@ -59,9 +58,15 @@ class AgentMeta(ModelMetaclass):
|
||||
|
||||
a2a_value = getattr(self, "a2a", None)
|
||||
if a2a_value is not None:
|
||||
from crewai.a2a.extensions.registry import (
|
||||
create_extension_registry_from_config,
|
||||
)
|
||||
from crewai.a2a.wrapper import wrap_agent_with_a2a_instance
|
||||
|
||||
wrap_agent_with_a2a_instance(self)
|
||||
extension_registry = create_extension_registry_from_config(
|
||||
a2a_value
|
||||
)
|
||||
wrap_agent_with_a2a_instance(self, extension_registry)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
355
lib/crewai/src/crewai/agent/utils.py
Normal file
355
lib/crewai/src/crewai/agent/utils.py
Normal file
@@ -0,0 +1,355 @@
|
||||
"""Utility functions for agent task execution.
|
||||
|
||||
This module contains shared logic extracted from the Agent's execute_task
|
||||
and aexecute_task methods to reduce code duplication.
|
||||
"""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.knowledge_events import (
|
||||
KnowledgeRetrievalCompletedEvent,
|
||||
KnowledgeRetrievalStartedEvent,
|
||||
KnowledgeSearchQueryFailedEvent,
|
||||
)
|
||||
from crewai.knowledge.utils.knowledge_utils import extract_knowledge_context
|
||||
from crewai.utilities.converter import generate_model_description
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.i18n import I18N
|
||||
|
||||
|
||||
def handle_reasoning(agent: Agent, task: Task) -> None:
|
||||
"""Handle the reasoning process for an agent before task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task to execute.
|
||||
"""
|
||||
if not agent.reasoning:
|
||||
return
|
||||
|
||||
try:
|
||||
from crewai.utilities.reasoning_handler import (
|
||||
AgentReasoning,
|
||||
AgentReasoningOutput,
|
||||
)
|
||||
|
||||
reasoning_handler = AgentReasoning(task=task, agent=agent)
|
||||
reasoning_output: AgentReasoningOutput = (
|
||||
reasoning_handler.handle_agent_reasoning()
|
||||
)
|
||||
task.description += f"\n\nReasoning Plan:\n{reasoning_output.plan.plan}"
|
||||
except Exception as e:
|
||||
agent._logger.log("error", f"Error during reasoning process: {e!s}")
|
||||
|
||||
|
||||
def build_task_prompt_with_schema(task: Task, task_prompt: str, i18n: I18N) -> str:
|
||||
"""Build task prompt with JSON/Pydantic schema instructions if applicable.
|
||||
|
||||
Args:
|
||||
task: The task being executed.
|
||||
task_prompt: The initial task prompt.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with schema instructions.
|
||||
"""
|
||||
if (task.output_json or task.output_pydantic) and not task.response_model:
|
||||
if task.output_json:
|
||||
schema_dict = generate_model_description(task.output_json)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
elif task.output_pydantic:
|
||||
schema_dict = generate_model_description(task.output_pydantic)
|
||||
schema = json.dumps(schema_dict["json_schema"]["schema"], indent=2)
|
||||
task_prompt += "\n" + i18n.slice("formatted_task_instructions").format(
|
||||
output_format=schema
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def format_task_with_context(task_prompt: str, context: str | None, i18n: I18N) -> str:
|
||||
"""Format task prompt with context if provided.
|
||||
|
||||
Args:
|
||||
task_prompt: The task prompt.
|
||||
context: Optional context string.
|
||||
i18n: Internationalization instance.
|
||||
|
||||
Returns:
|
||||
The task prompt formatted with context if provided.
|
||||
"""
|
||||
if context:
|
||||
return i18n.slice("task_with_context").format(task=task_prompt, context=context)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def get_knowledge_config(agent: Agent) -> dict[str, Any]:
|
||||
"""Get knowledge configuration from agent.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
|
||||
Returns:
|
||||
Dictionary of knowledge configuration.
|
||||
"""
|
||||
return agent.knowledge_config.model_dump() if agent.knowledge_config else {}
|
||||
|
||||
|
||||
def handle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
query_func: Any,
|
||||
crew_query_func: Any,
|
||||
) -> str:
|
||||
"""Handle knowledge retrieval for task execution.
|
||||
|
||||
This function handles both agent-specific and crew-specific knowledge queries.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
query_func: Function to query agent knowledge (sync or async).
|
||||
crew_query_func: Function to query crew knowledge (sync or async).
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = crew_query_func(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
|
||||
|
||||
def _combine_knowledge_context(agent: Agent) -> str:
|
||||
"""Combine agent and crew knowledge contexts into a single string.
|
||||
|
||||
Args:
|
||||
agent: The agent with knowledge contexts.
|
||||
|
||||
Returns:
|
||||
Combined knowledge context string.
|
||||
"""
|
||||
agent_ctx = agent.agent_knowledge_context or ""
|
||||
crew_ctx = agent.crew_knowledge_context or ""
|
||||
separator = "\n" if agent_ctx and crew_ctx else ""
|
||||
return agent_ctx + separator + crew_ctx
|
||||
|
||||
|
||||
def apply_training_data(agent: Agent, task_prompt: str) -> str:
|
||||
"""Apply training data to the task prompt.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task_prompt: The task prompt.
|
||||
|
||||
Returns:
|
||||
The task prompt with training data applied.
|
||||
"""
|
||||
if agent.crew and agent.crew._train:
|
||||
return agent._training_handler(task_prompt=task_prompt)
|
||||
return agent._use_trained_data(task_prompt=task_prompt)
|
||||
|
||||
|
||||
def process_tool_results(agent: Agent, result: Any) -> Any:
|
||||
"""Process tool results, returning result_as_answer if applicable.
|
||||
|
||||
Args:
|
||||
agent: The agent with tool results.
|
||||
result: The current result.
|
||||
|
||||
Returns:
|
||||
The final result, potentially overridden by tool result_as_answer.
|
||||
"""
|
||||
for tool_result in agent.tools_results:
|
||||
if tool_result.get("result_as_answer", False):
|
||||
result = tool_result["result"]
|
||||
return result
|
||||
|
||||
|
||||
def save_last_messages(agent: Agent) -> None:
|
||||
"""Save the last messages from agent executor.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
"""
|
||||
agent._last_messages = (
|
||||
agent.agent_executor.messages.copy()
|
||||
if agent.agent_executor and hasattr(agent.agent_executor, "messages")
|
||||
else []
|
||||
)
|
||||
|
||||
|
||||
def prepare_tools(
|
||||
agent: Agent, tools: list[BaseTool] | None, task: Task
|
||||
) -> list[BaseTool]:
|
||||
"""Prepare tools for task execution and create agent executor.
|
||||
|
||||
Args:
|
||||
agent: The agent instance.
|
||||
tools: Optional list of tools.
|
||||
task: The task being executed.
|
||||
|
||||
Returns:
|
||||
The list of tools to use.
|
||||
"""
|
||||
final_tools = tools or agent.tools or []
|
||||
agent.create_agent_executor(tools=final_tools, task=task)
|
||||
return final_tools
|
||||
|
||||
|
||||
def validate_max_execution_time(max_execution_time: int | None) -> None:
|
||||
"""Validate max_execution_time parameter.
|
||||
|
||||
Args:
|
||||
max_execution_time: The maximum execution time to validate.
|
||||
|
||||
Raises:
|
||||
ValueError: If max_execution_time is not a positive integer.
|
||||
"""
|
||||
if max_execution_time is not None:
|
||||
if not isinstance(max_execution_time, int) or max_execution_time <= 0:
|
||||
raise ValueError(
|
||||
"Max Execution time must be a positive integer greater than zero"
|
||||
)
|
||||
|
||||
|
||||
async def ahandle_knowledge_retrieval(
|
||||
agent: Agent,
|
||||
task: Task,
|
||||
task_prompt: str,
|
||||
knowledge_config: dict[str, Any],
|
||||
) -> str:
|
||||
"""Handle async knowledge retrieval for task execution.
|
||||
|
||||
Args:
|
||||
agent: The agent performing the task.
|
||||
task: The task being executed.
|
||||
task_prompt: The current task prompt.
|
||||
knowledge_config: Knowledge configuration dictionary.
|
||||
|
||||
Returns:
|
||||
The task prompt potentially augmented with knowledge context.
|
||||
"""
|
||||
if not (agent.knowledge or (agent.crew and agent.crew.knowledge)):
|
||||
return task_prompt
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalStartedEvent(
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
try:
|
||||
agent.knowledge_search_query = agent._get_knowledge_search_query(
|
||||
task_prompt, task
|
||||
)
|
||||
if agent.knowledge_search_query:
|
||||
if agent.knowledge:
|
||||
agent_knowledge_snippets = await agent.knowledge.aquery(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if agent_knowledge_snippets:
|
||||
agent.agent_knowledge_context = extract_knowledge_context(
|
||||
agent_knowledge_snippets
|
||||
)
|
||||
if agent.agent_knowledge_context:
|
||||
task_prompt += agent.agent_knowledge_context
|
||||
|
||||
knowledge_snippets = await agent.crew.aquery_knowledge(
|
||||
[agent.knowledge_search_query], **knowledge_config
|
||||
)
|
||||
if knowledge_snippets:
|
||||
agent.crew_knowledge_context = extract_knowledge_context(
|
||||
knowledge_snippets
|
||||
)
|
||||
if agent.crew_knowledge_context:
|
||||
task_prompt += agent.crew_knowledge_context
|
||||
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeRetrievalCompletedEvent(
|
||||
query=agent.knowledge_search_query,
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
retrieved_knowledge=_combine_knowledge_context(agent),
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
agent,
|
||||
event=KnowledgeSearchQueryFailedEvent(
|
||||
query=agent.knowledge_search_query or "",
|
||||
error=str(e),
|
||||
from_task=task,
|
||||
from_agent=agent,
|
||||
),
|
||||
)
|
||||
return task_prompt
|
||||
@@ -265,7 +265,7 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
if not mcps:
|
||||
return mcps
|
||||
|
||||
validated_mcps = []
|
||||
validated_mcps: list[str | MCPServerConfig] = []
|
||||
for mcp in mcps:
|
||||
if isinstance(mcp, str):
|
||||
if mcp.startswith(("https://", "crewai-amp:")):
|
||||
@@ -347,6 +347,15 @@ class BaseAgent(BaseModel, ABC, metaclass=AgentMeta):
|
||||
) -> str:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
async def aexecute_task(
|
||||
self,
|
||||
task: Any,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> str:
|
||||
"""Execute a task asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def create_agent_executor(self, tools: list[BaseTool] | None = None) -> None:
|
||||
pass
|
||||
|
||||
@@ -28,6 +28,7 @@ from crewai.hooks.llm_hooks import (
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.agent_utils import (
|
||||
aget_llm_response,
|
||||
enforce_rpm_limit,
|
||||
format_message_for_llm,
|
||||
get_llm_response,
|
||||
@@ -43,7 +44,10 @@ from crewai.utilities.agent_utils import (
|
||||
from crewai.utilities.constants import TRAINING_DATA_FILE
|
||||
from crewai.utilities.i18n import I18N, get_i18n
|
||||
from crewai.utilities.printer import Printer
|
||||
from crewai.utilities.tool_utils import execute_tool_and_check_finality
|
||||
from crewai.utilities.tool_utils import (
|
||||
aexecute_tool_and_check_finality,
|
||||
execute_tool_and_check_finality,
|
||||
)
|
||||
from crewai.utilities.training_handler import CrewTrainingHandler
|
||||
|
||||
|
||||
@@ -134,8 +138,8 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self.messages: list[LLMMessage] = []
|
||||
self.iterations = 0
|
||||
self.log_error_after = 3
|
||||
self.before_llm_call_hooks: list[Callable] = []
|
||||
self.after_llm_call_hooks: list[Callable] = []
|
||||
self.before_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.after_llm_call_hooks: list[Callable[..., Any]] = []
|
||||
self.before_llm_call_hooks.extend(get_before_llm_call_hooks())
|
||||
self.after_llm_call_hooks.extend(get_after_llm_call_hooks())
|
||||
if self.llm:
|
||||
@@ -312,6 +316,154 @@ class CrewAgentExecutor(CrewAgentExecutorMixin):
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
async def ainvoke(self, inputs: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Execute the agent asynchronously with given inputs.
|
||||
|
||||
Args:
|
||||
inputs: Input dictionary containing prompt variables.
|
||||
|
||||
Returns:
|
||||
Dictionary with agent output.
|
||||
"""
|
||||
if "system" in self.prompt:
|
||||
system_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("system", "")), inputs
|
||||
)
|
||||
user_prompt = self._format_prompt(
|
||||
cast(str, self.prompt.get("user", "")), inputs
|
||||
)
|
||||
self.messages.append(format_message_for_llm(system_prompt, role="system"))
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
else:
|
||||
user_prompt = self._format_prompt(self.prompt.get("prompt", ""), inputs)
|
||||
self.messages.append(format_message_for_llm(user_prompt))
|
||||
|
||||
self._show_start_logs()
|
||||
|
||||
self.ask_for_human_input = bool(inputs.get("ask_for_human_input", False))
|
||||
|
||||
try:
|
||||
formatted_answer = await self._ainvoke_loop()
|
||||
except AssertionError:
|
||||
self._printer.print(
|
||||
content="Agent failed to reach a final answer. This is likely a bug - please report it.",
|
||||
color="red",
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise
|
||||
|
||||
if self.ask_for_human_input:
|
||||
formatted_answer = self._handle_human_feedback(formatted_answer)
|
||||
|
||||
self._create_short_term_memory(formatted_answer)
|
||||
self._create_long_term_memory(formatted_answer)
|
||||
self._create_external_memory(formatted_answer)
|
||||
return {"output": formatted_answer.output}
|
||||
|
||||
async def _ainvoke_loop(self) -> AgentFinish:
|
||||
"""Execute agent loop asynchronously until completion.
|
||||
|
||||
Returns:
|
||||
Final answer from the agent.
|
||||
"""
|
||||
formatted_answer = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self.iterations, self.max_iter):
|
||||
formatted_answer = handle_max_iterations_exceeded(
|
||||
formatted_answer,
|
||||
printer=self._printer,
|
||||
i18n=self._i18n,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
)
|
||||
break
|
||||
|
||||
enforce_rpm_limit(self.request_within_rpm_limit)
|
||||
|
||||
answer = await aget_llm_response(
|
||||
llm=self.llm,
|
||||
messages=self.messages,
|
||||
callbacks=self.callbacks,
|
||||
printer=self._printer,
|
||||
from_task=self.task,
|
||||
from_agent=self.agent,
|
||||
response_model=self.response_model,
|
||||
executor_context=self,
|
||||
)
|
||||
formatted_answer = process_llm_response(answer, self.use_stop_words) # type: ignore[assignment]
|
||||
|
||||
if isinstance(formatted_answer, AgentAction):
|
||||
fingerprint_context = {}
|
||||
if (
|
||||
self.agent
|
||||
and hasattr(self.agent, "security_config")
|
||||
and hasattr(self.agent.security_config, "fingerprint")
|
||||
):
|
||||
fingerprint_context = {
|
||||
"agent_fingerprint": str(
|
||||
self.agent.security_config.fingerprint
|
||||
)
|
||||
}
|
||||
|
||||
tool_result = await aexecute_tool_and_check_finality(
|
||||
agent_action=formatted_answer,
|
||||
fingerprint_context=fingerprint_context,
|
||||
tools=self.tools,
|
||||
i18n=self._i18n,
|
||||
agent_key=self.agent.key if self.agent else None,
|
||||
agent_role=self.agent.role if self.agent else None,
|
||||
tools_handler=self.tools_handler,
|
||||
task=self.task,
|
||||
agent=self.agent,
|
||||
function_calling_llm=self.function_calling_llm,
|
||||
crew=self.crew,
|
||||
)
|
||||
formatted_answer = self._handle_agent_action(
|
||||
formatted_answer, tool_result
|
||||
)
|
||||
|
||||
self._invoke_step_callback(formatted_answer) # type: ignore[arg-type]
|
||||
self._append_message(formatted_answer.text) # type: ignore[union-attr,attr-defined]
|
||||
|
||||
except OutputParserError as e:
|
||||
formatted_answer = handle_output_parser_exception( # type: ignore[assignment]
|
||||
e=e,
|
||||
messages=self.messages,
|
||||
iterations=self.iterations,
|
||||
log_error_after=self.log_error_after,
|
||||
printer=self._printer,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if e.__class__.__module__.startswith("litellm"):
|
||||
raise e
|
||||
if is_context_length_exceeded(e):
|
||||
handle_context_length(
|
||||
respect_context_window=self.respect_context_window,
|
||||
printer=self._printer,
|
||||
messages=self.messages,
|
||||
llm=self.llm,
|
||||
callbacks=self.callbacks,
|
||||
i18n=self._i18n,
|
||||
)
|
||||
continue
|
||||
handle_unknown_error(self._printer, e)
|
||||
raise e
|
||||
finally:
|
||||
self.iterations += 1
|
||||
|
||||
if not isinstance(formatted_answer, AgentFinish):
|
||||
raise RuntimeError(
|
||||
"Agent execution ended without reaching a final answer. "
|
||||
f"Got {type(formatted_answer).__name__} instead of AgentFinish."
|
||||
)
|
||||
self._show_logs(formatted_answer)
|
||||
return formatted_answer
|
||||
|
||||
def _handle_agent_action(
|
||||
self, formatted_answer: AgentAction, tool_result: ToolResult
|
||||
) -> AgentAction | AgentFinish:
|
||||
|
||||
@@ -14,7 +14,8 @@ import tomli
|
||||
from crewai.cli.utils import read_toml
|
||||
from crewai.cli.version import get_crewai_version
|
||||
from crewai.crew import Crew
|
||||
from crewai.llm import LLM, BaseLLM
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.types.crew_chat import ChatInputField, ChatInputs
|
||||
from crewai.utilities.llm_utils import create_llm
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -27,7 +28,7 @@ MIN_REQUIRED_VERSION: Final[Literal["0.98.0"]] = "0.98.0"
|
||||
|
||||
|
||||
def check_conversational_crews_version(
|
||||
crewai_version: str, pyproject_data: dict
|
||||
crewai_version: str, pyproject_data: dict[str, Any]
|
||||
) -> bool:
|
||||
"""
|
||||
Check if the installed crewAI version supports conversational crews.
|
||||
@@ -53,7 +54,7 @@ def check_conversational_crews_version(
|
||||
return True
|
||||
|
||||
|
||||
def run_chat():
|
||||
def run_chat() -> None:
|
||||
"""
|
||||
Runs an interactive chat loop using the Crew's chat LLM with function calling.
|
||||
Incorporates crew_name, crew_description, and input fields to build a tool schema.
|
||||
@@ -101,7 +102,7 @@ def run_chat():
|
||||
|
||||
click.secho(f"Assistant: {introductory_message}\n", fg="green")
|
||||
|
||||
messages = [
|
||||
messages: list[LLMMessage] = [
|
||||
{"role": "system", "content": system_message},
|
||||
{"role": "assistant", "content": introductory_message},
|
||||
]
|
||||
@@ -113,7 +114,7 @@ def run_chat():
|
||||
chat_loop(chat_llm, messages, crew_tool_schema, available_functions)
|
||||
|
||||
|
||||
def show_loading(event: threading.Event):
|
||||
def show_loading(event: threading.Event) -> None:
|
||||
"""Display animated loading dots while processing."""
|
||||
while not event.is_set():
|
||||
_printer.print(".", end="")
|
||||
@@ -162,23 +163,23 @@ def build_system_message(crew_chat_inputs: ChatInputs) -> str:
|
||||
)
|
||||
|
||||
|
||||
def create_tool_function(crew: Crew, messages: list[dict[str, str]]) -> Any:
|
||||
def create_tool_function(crew: Crew, messages: list[LLMMessage]) -> Any:
|
||||
"""Creates a wrapper function for running the crew tool with messages."""
|
||||
|
||||
def run_crew_tool_with_messages(**kwargs):
|
||||
def run_crew_tool_with_messages(**kwargs: Any) -> str:
|
||||
return run_crew_tool(crew, messages, **kwargs)
|
||||
|
||||
return run_crew_tool_with_messages
|
||||
|
||||
|
||||
def flush_input():
|
||||
def flush_input() -> None:
|
||||
"""Flush any pending input from the user."""
|
||||
if platform.system() == "Windows":
|
||||
# Windows platform
|
||||
import msvcrt
|
||||
|
||||
while msvcrt.kbhit():
|
||||
msvcrt.getch()
|
||||
while msvcrt.kbhit(): # type: ignore[attr-defined]
|
||||
msvcrt.getch() # type: ignore[attr-defined]
|
||||
else:
|
||||
# Unix-like platforms (Linux, macOS)
|
||||
import termios
|
||||
@@ -186,7 +187,12 @@ def flush_input():
|
||||
termios.tcflush(sys.stdin, termios.TCIFLUSH)
|
||||
|
||||
|
||||
def chat_loop(chat_llm, messages, crew_tool_schema, available_functions):
|
||||
def chat_loop(
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
) -> None:
|
||||
"""Main chat loop for interacting with the user."""
|
||||
while True:
|
||||
try:
|
||||
@@ -225,7 +231,7 @@ def get_user_input() -> str:
|
||||
|
||||
def handle_user_input(
|
||||
user_input: str,
|
||||
chat_llm: LLM,
|
||||
chat_llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
crew_tool_schema: dict[str, Any],
|
||||
available_functions: dict[str, Any],
|
||||
@@ -255,7 +261,7 @@ def handle_user_input(
|
||||
click.secho(f"\nAssistant: {final_response}\n", fg="green")
|
||||
|
||||
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict[str, Any]:
|
||||
"""
|
||||
Dynamically build a Littellm 'function' schema for the given crew.
|
||||
|
||||
@@ -286,7 +292,7 @@ def generate_crew_tool_schema(crew_inputs: ChatInputs) -> dict:
|
||||
}
|
||||
|
||||
|
||||
def run_crew_tool(crew: Crew, messages: list[dict[str, str]], **kwargs):
|
||||
def run_crew_tool(crew: Crew, messages: list[LLMMessage], **kwargs: Any) -> str:
|
||||
"""
|
||||
Runs the crew using crew.kickoff(inputs=kwargs) and returns the output.
|
||||
|
||||
@@ -372,7 +378,9 @@ def load_crew_and_name() -> tuple[Crew, str]:
|
||||
return crew_instance, crew_class_name
|
||||
|
||||
|
||||
def generate_crew_chat_inputs(crew: Crew, crew_name: str, chat_llm) -> ChatInputs:
|
||||
def generate_crew_chat_inputs(
|
||||
crew: Crew, crew_name: str, chat_llm: LLM | BaseLLM
|
||||
) -> ChatInputs:
|
||||
"""
|
||||
Generates the ChatInputs required for the crew by analyzing the tasks and agents.
|
||||
|
||||
@@ -410,23 +418,12 @@ def fetch_required_inputs(crew: Crew) -> set[str]:
|
||||
Returns:
|
||||
Set[str]: A set of placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks
|
||||
for task in crew.tasks:
|
||||
text = f"{task.description or ''} {task.expected_output or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
# Scan agents
|
||||
for agent in crew.agents:
|
||||
text = f"{agent.role or ''} {agent.goal or ''} {agent.backstory or ''}"
|
||||
required_inputs.update(placeholder_pattern.findall(text))
|
||||
|
||||
return required_inputs
|
||||
return crew.fetch_inputs()
|
||||
|
||||
|
||||
def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) -> str:
|
||||
def generate_input_description_with_ai(
|
||||
input_name: str, crew: Crew, chat_llm: LLM | BaseLLM
|
||||
) -> str:
|
||||
"""
|
||||
Generates an input description using AI based on the context of the crew.
|
||||
|
||||
@@ -484,10 +481,10 @@ def generate_input_description_with_ai(input_name: str, crew: Crew, chat_llm) ->
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
def generate_crew_description_with_ai(crew: Crew, chat_llm: LLM | BaseLLM) -> str:
|
||||
"""
|
||||
Generates a brief description of the crew using AI.
|
||||
|
||||
@@ -534,4 +531,4 @@ def generate_crew_description_with_ai(crew: Crew, chat_llm) -> str:
|
||||
f"{context}"
|
||||
)
|
||||
response = chat_llm.call(messages=[{"role": "user", "content": prompt}])
|
||||
return response.strip()
|
||||
return str(response).strip()
|
||||
|
||||
@@ -3,103 +3,56 @@ import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
import sys
|
||||
from typing import BinaryIO, cast
|
||||
import tempfile
|
||||
from typing import Final, Literal, cast
|
||||
|
||||
from cryptography.fernet import Fernet
|
||||
|
||||
|
||||
if sys.platform == "win32":
|
||||
import msvcrt
|
||||
else:
|
||||
import fcntl
|
||||
_FERNET_KEY_LENGTH: Final[Literal[44]] = 44
|
||||
|
||||
|
||||
class TokenManager:
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""
|
||||
Initialize the TokenManager class.
|
||||
"""Manages encrypted token storage."""
|
||||
|
||||
:param file_path: The file path to store the encrypted tokens. Default is "tokens.enc".
|
||||
def __init__(self, file_path: str = "tokens.enc") -> None:
|
||||
"""Initialize the TokenManager.
|
||||
|
||||
Args:
|
||||
file_path: The file path to store encrypted tokens.
|
||||
"""
|
||||
self.file_path = file_path
|
||||
self.key = self._get_or_create_key()
|
||||
self.fernet = Fernet(self.key)
|
||||
|
||||
@staticmethod
|
||||
def _acquire_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Acquire an exclusive lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to lock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_LOCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_EX)
|
||||
|
||||
@staticmethod
|
||||
def _release_lock(file_handle: BinaryIO) -> None:
|
||||
"""
|
||||
Release the lock on a file handle.
|
||||
|
||||
Args:
|
||||
file_handle: Open file handle to unlock.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
msvcrt.locking(file_handle.fileno(), msvcrt.LK_UNLCK, 1)
|
||||
else:
|
||||
fcntl.flock(file_handle.fileno(), fcntl.LOCK_UN)
|
||||
|
||||
def _get_or_create_key(self) -> bytes:
|
||||
"""
|
||||
Get or create the encryption key with file locking to prevent race conditions.
|
||||
"""Get or create the encryption key.
|
||||
|
||||
Returns:
|
||||
The encryption key.
|
||||
The encryption key as bytes.
|
||||
"""
|
||||
key_filename = "secret.key"
|
||||
storage_path = self.get_secure_storage_path()
|
||||
key_filename: str = "secret.key"
|
||||
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
lock_file_path = storage_path / f"{key_filename}.lock"
|
||||
|
||||
try:
|
||||
lock_file_path.touch()
|
||||
|
||||
with open(lock_file_path, "r+b") as lock_file:
|
||||
self._acquire_lock(lock_file)
|
||||
try:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
return new_key
|
||||
finally:
|
||||
try:
|
||||
self._release_lock(lock_file)
|
||||
except OSError:
|
||||
pass
|
||||
except OSError:
|
||||
key = self.read_secure_file(key_filename)
|
||||
if key is not None and len(key) == 44:
|
||||
return key
|
||||
|
||||
new_key = Fernet.generate_key()
|
||||
self.save_secure_file(key_filename, new_key)
|
||||
new_key = Fernet.generate_key()
|
||||
if self._atomic_create_secure_file(key_filename, new_key):
|
||||
return new_key
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""
|
||||
Save the access token and its expiration time.
|
||||
key = self._read_secure_file(key_filename)
|
||||
if key is not None and len(key) == _FERNET_KEY_LENGTH:
|
||||
return key
|
||||
|
||||
:param access_token: The access token to save.
|
||||
:param expires_at: The UNIX timestamp of the expiration time.
|
||||
raise RuntimeError("Failed to create or read encryption key")
|
||||
|
||||
def save_tokens(self, access_token: str, expires_at: int) -> None:
|
||||
"""Save the access token and its expiration time.
|
||||
|
||||
Args:
|
||||
access_token: The access token to save.
|
||||
expires_at: The UNIX timestamp of the expiration time.
|
||||
"""
|
||||
expiration_time = datetime.fromtimestamp(expires_at)
|
||||
data = {
|
||||
@@ -107,15 +60,15 @@ class TokenManager:
|
||||
"expiration": expiration_time.isoformat(),
|
||||
}
|
||||
encrypted_data = self.fernet.encrypt(json.dumps(data).encode())
|
||||
self.save_secure_file(self.file_path, encrypted_data)
|
||||
self._atomic_write_secure_file(self.file_path, encrypted_data)
|
||||
|
||||
def get_token(self) -> str | None:
|
||||
"""
|
||||
Get the access token if it is valid and not expired.
|
||||
"""Get the access token if it is valid and not expired.
|
||||
|
||||
:return: The access token if valid and not expired, otherwise None.
|
||||
Returns:
|
||||
The access token if valid and not expired, otherwise None.
|
||||
"""
|
||||
encrypted_data = self.read_secure_file(self.file_path)
|
||||
encrypted_data = self._read_secure_file(self.file_path)
|
||||
if encrypted_data is None:
|
||||
return None
|
||||
|
||||
@@ -126,20 +79,18 @@ class TokenManager:
|
||||
if expiration <= datetime.now():
|
||||
return None
|
||||
|
||||
return cast(str | None, data["access_token"])
|
||||
return cast(str | None, data.get("access_token"))
|
||||
|
||||
def clear_tokens(self) -> None:
|
||||
"""
|
||||
Clear the tokens.
|
||||
"""
|
||||
self.delete_secure_file(self.file_path)
|
||||
"""Clear the stored tokens."""
|
||||
self._delete_secure_file(self.file_path)
|
||||
|
||||
@staticmethod
|
||||
def get_secure_storage_path() -> Path:
|
||||
"""
|
||||
Get the secure storage path based on the operating system.
|
||||
def _get_secure_storage_path() -> Path:
|
||||
"""Get the secure storage path based on the operating system.
|
||||
|
||||
:return: The secure storage path.
|
||||
Returns:
|
||||
The secure storage path.
|
||||
"""
|
||||
if sys.platform == "win32":
|
||||
base_path = os.environ.get("LOCALAPPDATA")
|
||||
@@ -155,44 +106,81 @@ class TokenManager:
|
||||
|
||||
return storage_path
|
||||
|
||||
def save_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""
|
||||
Save the content to a secure file.
|
||||
def _atomic_create_secure_file(self, filename: str, content: bytes) -> bool:
|
||||
"""Create a file only if it doesn't exist.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:param content: The content to save.
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
|
||||
Returns:
|
||||
True if file was created, False if it already exists.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
with open(file_path, "wb") as f:
|
||||
f.write(content)
|
||||
try:
|
||||
fd = os.open(file_path, os.O_CREAT | os.O_EXCL | os.O_WRONLY, 0o600)
|
||||
try:
|
||||
os.write(fd, content)
|
||||
finally:
|
||||
os.close(fd)
|
||||
return True
|
||||
except FileExistsError:
|
||||
return False
|
||||
|
||||
os.chmod(file_path, 0o600)
|
||||
def _atomic_write_secure_file(self, filename: str, content: bytes) -> None:
|
||||
"""Write content to a secure file.
|
||||
|
||||
def read_secure_file(self, filename: str) -> bytes | None:
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
content: The content to write.
|
||||
"""
|
||||
Read the content of a secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
:return: The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
if not file_path.exists():
|
||||
fd, temp_path = tempfile.mkstemp(dir=storage_path, prefix=f".{filename}.")
|
||||
fd_closed = False
|
||||
try:
|
||||
os.write(fd, content)
|
||||
os.close(fd)
|
||||
fd_closed = True
|
||||
os.chmod(temp_path, 0o600)
|
||||
os.replace(temp_path, file_path)
|
||||
except Exception:
|
||||
if not fd_closed:
|
||||
os.close(fd)
|
||||
if os.path.exists(temp_path):
|
||||
os.unlink(temp_path)
|
||||
raise
|
||||
|
||||
def _read_secure_file(self, filename: str) -> bytes | None:
|
||||
"""Read the content of a secure file.
|
||||
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
|
||||
Returns:
|
||||
The content of the file if it exists, otherwise None.
|
||||
"""
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
|
||||
try:
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
except FileNotFoundError:
|
||||
return None
|
||||
|
||||
with open(file_path, "rb") as f:
|
||||
return f.read()
|
||||
def _delete_secure_file(self, filename: str) -> None:
|
||||
"""Delete a secure file.
|
||||
|
||||
def delete_secure_file(self, filename: str) -> None:
|
||||
Args:
|
||||
filename: The name of the file.
|
||||
"""
|
||||
Delete the secure file.
|
||||
|
||||
:param filename: The name of the file.
|
||||
"""
|
||||
storage_path = self.get_secure_storage_path()
|
||||
storage_path = self._get_secure_storage_path()
|
||||
file_path = storage_path / filename
|
||||
if file_path.exists():
|
||||
file_path.unlink(missing_ok=True)
|
||||
try:
|
||||
file_path.unlink()
|
||||
except FileNotFoundError:
|
||||
pass
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.7.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -5,7 +5,7 @@ description = "{{name}} using crewAI"
|
||||
authors = [{ name = "Your Name", email = "you@example.com" }]
|
||||
requires-python = ">=3.10,<3.14"
|
||||
dependencies = [
|
||||
"crewai[tools]==1.6.1"
|
||||
"crewai[tools]==1.7.0"
|
||||
]
|
||||
|
||||
[project.scripts]
|
||||
|
||||
@@ -35,6 +35,14 @@ from crewai.agent import Agent
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.agents.cache.cache_handler import CacheHandler
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.crews.utils import (
|
||||
StreamingContext,
|
||||
check_conditional_skip,
|
||||
enable_agent_streaming,
|
||||
prepare_kickoff,
|
||||
prepare_task_execution,
|
||||
run_for_each_async,
|
||||
)
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.event_listener import EventListener
|
||||
from crewai.events.listeners.tracing.trace_listener import (
|
||||
@@ -47,7 +55,6 @@ from crewai.events.listeners.tracing.utils import (
|
||||
from crewai.events.types.crew_events import (
|
||||
CrewKickoffCompletedEvent,
|
||||
CrewKickoffFailedEvent,
|
||||
CrewKickoffStartedEvent,
|
||||
CrewTestCompletedEvent,
|
||||
CrewTestFailedEvent,
|
||||
CrewTestStartedEvent,
|
||||
@@ -74,7 +81,7 @@ from crewai.tasks.conditional_task import ConditionalTask
|
||||
from crewai.tasks.task_output import TaskOutput
|
||||
from crewai.tools.agent_tools.agent_tools import AgentTools
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.types.streaming import CrewStreamingOutput
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.constants import NOT_SPECIFIED, TRAINING_DATA_FILE
|
||||
from crewai.utilities.crew.models import CrewContext
|
||||
@@ -92,10 +99,8 @@ from crewai.utilities.planning_handler import CrewPlanner
|
||||
from crewai.utilities.printer import PrinterColor
|
||||
from crewai.utilities.rpm_controller import RPMController
|
||||
from crewai.utilities.streaming import (
|
||||
TaskInfo,
|
||||
create_async_chunk_generator,
|
||||
create_chunk_generator,
|
||||
create_streaming_state,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
@@ -268,7 +273,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
description="list of file paths for task execution JSON files.",
|
||||
)
|
||||
execution_logs: list[dict[str, Any]] = Field(
|
||||
default=[],
|
||||
default_factory=list,
|
||||
description="list of execution logs for tasks",
|
||||
)
|
||||
knowledge_sources: list[BaseKnowledgeSource] | None = Field(
|
||||
@@ -348,12 +353,12 @@ class Crew(FlowTrackable, BaseModel):
|
||||
return self
|
||||
|
||||
def _initialize_default_memories(self) -> None:
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory() # type: ignore[no-untyped-call]
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory( # type: ignore[no-untyped-call]
|
||||
self._long_term_memory = self._long_term_memory or LongTermMemory()
|
||||
self._short_term_memory = self._short_term_memory or ShortTermMemory(
|
||||
crew=self,
|
||||
embedder_config=self.embedder,
|
||||
)
|
||||
self._entity_memory = self.entity_memory or EntityMemory( # type: ignore[no-untyped-call]
|
||||
self._entity_memory = self.entity_memory or EntityMemory(
|
||||
crew=self, embedder_config=self.embedder
|
||||
)
|
||||
|
||||
@@ -404,8 +409,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"missing_manager_llm_or_manager_agent",
|
||||
(
|
||||
"Attribute `manager_llm` or `manager_agent` is required "
|
||||
"when using hierarchical process."
|
||||
"Attribute `manager_llm` or `manager_agent` is required when using hierarchical process."
|
||||
),
|
||||
{},
|
||||
)
|
||||
@@ -511,10 +515,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
raise PydanticCustomError(
|
||||
"invalid_async_conditional_task",
|
||||
(
|
||||
f"Conditional Task: {task.description}, "
|
||||
f"cannot be executed asynchronously."
|
||||
"Conditional Task: {description}, cannot be executed asynchronously."
|
||||
),
|
||||
{},
|
||||
{"description": task.description},
|
||||
)
|
||||
return self
|
||||
|
||||
@@ -675,21 +678,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs: dict[str, Any] | None = None,
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(current_task_info, result_holder)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext()
|
||||
|
||||
def run_crew() -> None:
|
||||
"""Execute the crew and capture the result."""
|
||||
@@ -697,59 +687,28 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self.stream = False
|
||||
crew_result = self.kickoff(inputs=inputs)
|
||||
if isinstance(crew_result, CrewOutput):
|
||||
result_holder.append(crew_result)
|
||||
ctx.result_holder.append(crew_result)
|
||||
except Exception as exc:
|
||||
signal_error(state, exc)
|
||||
signal_error(ctx.state, exc)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state)
|
||||
signal_end(ctx.state)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
sync_iterator=create_chunk_generator(state, run_crew, output_holder)
|
||||
sync_iterator=create_chunk_generator(
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
return streaming_output
|
||||
|
||||
ctx = baggage.set_baggage(
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(ctx)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
try:
|
||||
for before_callback in self.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffStartedEvent(crew_name=self.name, inputs=inputs),
|
||||
)
|
||||
|
||||
# Starts the crew to work on its assigned tasks.
|
||||
self._task_output_handler.reset()
|
||||
self._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
self._inputs = inputs
|
||||
self._interpolate_inputs(inputs)
|
||||
self._set_tasks_callbacks()
|
||||
self._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
for agent in self.agents:
|
||||
agent.crew = self
|
||||
agent.set_knowledge(crew_embedder=self.embedder)
|
||||
# TODO: Create an AgentFunctionCalling protocol for future refactoring
|
||||
if not agent.function_calling_llm: # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
agent.function_calling_llm = self.function_calling_llm # type: ignore # "BaseAgent" has no attribute "function_calling_llm"
|
||||
|
||||
if not agent.step_callback: # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
agent.step_callback = self.step_callback # type: ignore # "BaseAgent" has no attribute "step_callback"
|
||||
|
||||
agent.create_agent_executor()
|
||||
|
||||
if self.planning:
|
||||
self._handle_crew_planning()
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
if self.process == Process.sequential:
|
||||
result = self._run_sequential_process()
|
||||
@@ -814,42 +773,27 @@ class Crew(FlowTrackable, BaseModel):
|
||||
inputs = inputs or {}
|
||||
|
||||
if self.stream:
|
||||
for agent in self.agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
result_holder: list[CrewOutput] = []
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
self.stream = False
|
||||
result = await asyncio.to_thread(self.kickoff, inputs)
|
||||
if isinstance(result, CrewOutput):
|
||||
result_holder.append(result)
|
||||
ctx.result_holder.append(result)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
self.stream = True
|
||||
signal_end(state, is_async=True)
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_crew, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
@@ -864,89 +808,207 @@ class Crew(FlowTrackable, BaseModel):
|
||||
from all crews as they arrive. After iteration, access results via .results
|
||||
(list of CrewOutput).
|
||||
"""
|
||||
crew_copies = [self.copy() for _ in inputs]
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.kickoff_async(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
|
||||
async def akickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
"""Native async kickoff method using async task execution throughout.
|
||||
|
||||
Unlike kickoff_async which wraps sync kickoff in a thread, this method
|
||||
uses native async/await for all operations including task execution,
|
||||
memory operations, and knowledge queries.
|
||||
"""
|
||||
if self.stream:
|
||||
result_holder: list[list[CrewOutput]] = [[]]
|
||||
current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
enable_agent_streaming(self.agents)
|
||||
ctx = StreamingContext(use_async=True)
|
||||
|
||||
state = create_streaming_state(
|
||||
current_task_info, result_holder, use_async=True
|
||||
)
|
||||
output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
"""Run all crew copies and aggregate their streaming outputs."""
|
||||
async def run_crew() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew in enumerate(crew_copies):
|
||||
streaming = await crew.kickoff_async(inputs=inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
"""Consume stream chunks and forward to parent queue.
|
||||
|
||||
Args:
|
||||
stream_output: The streaming output to consume.
|
||||
|
||||
Returns:
|
||||
The final CrewOutput result.
|
||||
"""
|
||||
async for chunk in stream_output:
|
||||
if state.async_queue is not None and state.loop is not None:
|
||||
state.loop.call_soon_threadsafe(
|
||||
state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(state, e, is_async=True)
|
||||
self.stream = False
|
||||
inner_result = await self.akickoff(inputs)
|
||||
if isinstance(inner_result, CrewOutput):
|
||||
ctx.result_holder.append(inner_result)
|
||||
except Exception as exc:
|
||||
signal_error(ctx.state, exc, is_async=True)
|
||||
finally:
|
||||
signal_end(state, is_async=True)
|
||||
self.stream = True
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
state, run_all_crews, output_holder
|
||||
ctx.state, run_crew, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
"""Wrap _set_results to match _set_result signature."""
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
output_holder.append(streaming_output)
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
tasks = [
|
||||
asyncio.create_task(crew_copy.kickoff_async(inputs=input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
baggage_ctx = baggage.set_baggage(
|
||||
"crew_context", CrewContext(id=str(self.id), key=self.key)
|
||||
)
|
||||
token = attach(baggage_ctx)
|
||||
|
||||
results = await asyncio.gather(*tasks)
|
||||
try:
|
||||
inputs = prepare_kickoff(self, inputs)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
self.usage_metrics = total_usage_metrics
|
||||
if self.process == Process.sequential:
|
||||
result = await self._arun_sequential_process()
|
||||
elif self.process == Process.hierarchical:
|
||||
result = await self._arun_hierarchical_process()
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
f"The process '{self.process}' is not implemented yet."
|
||||
)
|
||||
|
||||
self._task_output_handler.reset()
|
||||
return list(results)
|
||||
for after_callback in self.after_kickoff_callbacks:
|
||||
result = after_callback(result)
|
||||
|
||||
self.usage_metrics = self.calculate_usage_metrics()
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
CrewKickoffFailedEvent(error=str(e), crew_name=self.name),
|
||||
)
|
||||
raise
|
||||
finally:
|
||||
detach(token)
|
||||
|
||||
async def akickoff_for_each(
|
||||
self, inputs: list[dict[str, Any]]
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Native async execution of the Crew's workflow for each input.
|
||||
|
||||
Uses native async throughout rather than thread-based async.
|
||||
If stream=True, returns a single CrewStreamingOutput that yields chunks
|
||||
from all crews as they arrive.
|
||||
"""
|
||||
|
||||
async def kickoff_fn(
|
||||
crew: Crew, input_data: dict[str, Any]
|
||||
) -> CrewOutput | CrewStreamingOutput:
|
||||
return await crew.akickoff(inputs=input_data)
|
||||
|
||||
return await run_for_each_async(self, inputs, kickoff_fn)
|
||||
|
||||
async def _arun_sequential_process(self) -> CrewOutput:
|
||||
"""Executes tasks sequentially using native async and returns the final output."""
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _arun_hierarchical_process(self) -> CrewOutput:
|
||||
"""Creates and assigns a manager agent to complete the tasks using native async."""
|
||||
self._create_manager_agent()
|
||||
return await self._aexecute_tasks(self.tasks)
|
||||
|
||||
async def _aexecute_tasks(
|
||||
self,
|
||||
tasks: list[Task],
|
||||
start_index: int | None = 0,
|
||||
was_replayed: bool = False,
|
||||
) -> CrewOutput:
|
||||
"""Executes tasks using native async and returns the final output.
|
||||
|
||||
Args:
|
||||
tasks: List of tasks to execute
|
||||
start_index: Index to start execution from (for replay)
|
||||
was_replayed: Whether this is a replayed execution
|
||||
|
||||
Returns:
|
||||
CrewOutput: Final output of the crew
|
||||
"""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]] = []
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = await self._ahandle_conditional_task(
|
||||
task, task_outputs, pending_tasks, task_index, was_replayed
|
||||
)
|
||||
if skipped_task_output:
|
||||
task_outputs.append(skipped_task_output)
|
||||
continue
|
||||
|
||||
if task.async_execution:
|
||||
context = self._get_context(
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
async_task = asyncio.create_task(
|
||||
task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
)
|
||||
pending_tasks.append((task, async_task, task_index))
|
||||
else:
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(
|
||||
pending_tasks, was_replayed
|
||||
)
|
||||
pending_tasks.clear()
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = await task.aexecute_sync(
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
self._store_execution_log(task, task_output, task_index, was_replayed)
|
||||
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
|
||||
return self._create_crew_output(task_outputs)
|
||||
|
||||
async def _ahandle_conditional_task(
|
||||
self,
|
||||
task: ConditionalTask,
|
||||
task_outputs: list[TaskOutput],
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> TaskOutput | None:
|
||||
"""Handle conditional task evaluation using native async."""
|
||||
if pending_tasks:
|
||||
task_outputs = await self._aprocess_async_tasks(pending_tasks, was_replayed)
|
||||
pending_tasks.clear()
|
||||
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
async def _aprocess_async_tasks(
|
||||
self,
|
||||
pending_tasks: list[tuple[Task, asyncio.Task[TaskOutput], int]],
|
||||
was_replayed: bool = False,
|
||||
) -> list[TaskOutput]:
|
||||
"""Process pending async tasks and return their outputs."""
|
||||
task_outputs: list[TaskOutput] = []
|
||||
for future_task, async_task, task_index in pending_tasks:
|
||||
task_output = await async_task
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(future_task, task_output)
|
||||
self._store_execution_log(
|
||||
future_task, task_output, task_index, was_replayed
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _handle_crew_planning(self) -> None:
|
||||
"""Handles the Crew planning."""
|
||||
@@ -1064,33 +1126,11 @@ class Crew(FlowTrackable, BaseModel):
|
||||
last_sync_output: TaskOutput | None = None
|
||||
|
||||
for task_index, task in enumerate(tasks):
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
continue
|
||||
|
||||
agent_to_use = self._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
# Determine which tools to use - task tools take precedence over agent tools
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
# Prepare tools and ensure they're compatible with task execution
|
||||
tools_for_task = self._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
exec_data, task_outputs, last_sync_output = prepare_task_execution(
|
||||
self, task, task_index, start_index, task_outputs, last_sync_output
|
||||
)
|
||||
|
||||
self._log_task_start(task, agent_to_use.role)
|
||||
if exec_data.should_skip:
|
||||
continue
|
||||
|
||||
if isinstance(task, ConditionalTask):
|
||||
skipped_task_output = self._handle_conditional_task(
|
||||
@@ -1105,9 +1145,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task, [last_sync_output] if last_sync_output else []
|
||||
)
|
||||
future = task.execute_async(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
futures.append((task, future, task_index))
|
||||
else:
|
||||
@@ -1117,9 +1157,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
context = self._get_context(task, task_outputs)
|
||||
task_output = task.execute_sync(
|
||||
agent=agent_to_use,
|
||||
agent=exec_data.agent,
|
||||
context=context,
|
||||
tools=tools_for_task,
|
||||
tools=exec_data.tools,
|
||||
)
|
||||
task_outputs.append(task_output)
|
||||
self._process_task_result(task, task_output)
|
||||
@@ -1142,19 +1182,9 @@ class Crew(FlowTrackable, BaseModel):
|
||||
task_outputs = self._process_async_tasks(futures, was_replayed)
|
||||
futures.clear()
|
||||
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
self._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
self._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
return check_conditional_skip(
|
||||
self, task, task_outputs, task_index, was_replayed
|
||||
)
|
||||
|
||||
def _prepare_tools(
|
||||
self, agent: BaseAgent, task: Task, tools: list[BaseTool]
|
||||
@@ -1318,7 +1348,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return tools
|
||||
|
||||
def _get_context(self, task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
@staticmethod
|
||||
def _get_context(task: Task, task_outputs: list[TaskOutput]) -> str:
|
||||
if not task.context:
|
||||
return ""
|
||||
|
||||
@@ -1387,7 +1418,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return task_outputs
|
||||
|
||||
def _find_task_index(self, task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
@staticmethod
|
||||
def _find_task_index(task_id: str, stored_outputs: list[Any]) -> int | None:
|
||||
return next(
|
||||
(
|
||||
index
|
||||
@@ -1447,6 +1479,16 @@ class Crew(FlowTrackable, BaseModel):
|
||||
)
|
||||
return None
|
||||
|
||||
async def aquery_knowledge(
|
||||
self, query: list[str], results_limit: int = 3, score_threshold: float = 0.35
|
||||
) -> list[SearchResult] | None:
|
||||
"""Query the crew's knowledge base for relevant information asynchronously."""
|
||||
if self.knowledge:
|
||||
return await self.knowledge.aquery(
|
||||
query, results_limit=results_limit, score_threshold=score_threshold
|
||||
)
|
||||
return None
|
||||
|
||||
def fetch_inputs(self) -> set[str]:
|
||||
"""
|
||||
Gathers placeholders (e.g., {something}) referenced in tasks or agents.
|
||||
@@ -1455,7 +1497,7 @@ class Crew(FlowTrackable, BaseModel):
|
||||
|
||||
Returns a set of all discovered placeholder names.
|
||||
"""
|
||||
placeholder_pattern = re.compile(r"\{(.+?)\}")
|
||||
placeholder_pattern = re.compile(r"\{(.+?)}")
|
||||
required_inputs: set[str] = set()
|
||||
|
||||
# Scan tasks for inputs
|
||||
@@ -1703,6 +1745,32 @@ class Crew(FlowTrackable, BaseModel):
|
||||
self._logger.log("error", error_msg)
|
||||
raise RuntimeError(error_msg) from e
|
||||
|
||||
def _reset_memory_system(
|
||||
self, system: Any, name: str, reset_fn: Callable[[Any], Any]
|
||||
) -> None:
|
||||
"""Reset a single memory system.
|
||||
|
||||
Args:
|
||||
system: The memory system instance to reset.
|
||||
name: Display name of the memory system for logging.
|
||||
reset_fn: Function to call to reset the system.
|
||||
|
||||
Raises:
|
||||
RuntimeError: If the reset operation fails.
|
||||
"""
|
||||
try:
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
|
||||
def _reset_all_memories(self) -> None:
|
||||
"""Reset all available memory systems."""
|
||||
memory_systems = self._get_memory_systems()
|
||||
@@ -1710,21 +1778,10 @@ class Crew(FlowTrackable, BaseModel):
|
||||
for config in memory_systems.values():
|
||||
if (system := config.get("system")) is not None:
|
||||
name = config.get("name")
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _reset_specific_memory(self, memory_type: str) -> None:
|
||||
"""Reset a specific memory system.
|
||||
@@ -1743,21 +1800,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
if system is None:
|
||||
raise RuntimeError(f"{name} memory system is not initialized")
|
||||
|
||||
try:
|
||||
reset_fn: Callable[[Any], Any] = cast(
|
||||
Callable[[Any], Any], config.get("reset")
|
||||
)
|
||||
reset_fn(system)
|
||||
self._logger.log(
|
||||
"info",
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"{name} memory has been reset",
|
||||
)
|
||||
except Exception as e:
|
||||
raise RuntimeError(
|
||||
f"[Crew ({self.name if self.name else self.id})] "
|
||||
f"Failed to reset {name} memory: {e!s}"
|
||||
) from e
|
||||
reset_fn: Callable[[Any], Any] = cast(Callable[[Any], Any], config.get("reset"))
|
||||
self._reset_memory_system(system, name, reset_fn)
|
||||
|
||||
def _get_memory_systems(self) -> dict[str, Any]:
|
||||
"""Get all available memory systems with their configuration.
|
||||
@@ -1845,7 +1889,8 @@ class Crew(FlowTrackable, BaseModel):
|
||||
):
|
||||
self.tasks[0].allow_crewai_trigger_context = True
|
||||
|
||||
def _show_tracing_disabled_message(self) -> None:
|
||||
@staticmethod
|
||||
def _show_tracing_disabled_message() -> None:
|
||||
"""Show a message when tracing is disabled."""
|
||||
from crewai.events.listeners.tracing.utils import has_user_declined_tracing
|
||||
|
||||
|
||||
363
lib/crewai/src/crewai/crews/utils.py
Normal file
363
lib/crewai/src/crewai/crews/utils.py
Normal file
@@ -0,0 +1,363 @@
|
||||
"""Utility functions for crew operations."""
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from collections.abc import Callable, Coroutine, Iterable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
from crewai.agents.agent_builder.base_agent import BaseAgent
|
||||
from crewai.crews.crew_output import CrewOutput
|
||||
from crewai.rag.embeddings.types import EmbedderConfig
|
||||
from crewai.types.streaming import CrewStreamingOutput, FlowStreamingOutput
|
||||
from crewai.utilities.streaming import (
|
||||
StreamingState,
|
||||
TaskInfo,
|
||||
create_streaming_state,
|
||||
)
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.crew import Crew
|
||||
|
||||
|
||||
def enable_agent_streaming(agents: Iterable[BaseAgent]) -> None:
|
||||
"""Enable streaming on all agents that have an LLM configured.
|
||||
|
||||
Args:
|
||||
agents: Iterable of agents to enable streaming on.
|
||||
"""
|
||||
for agent in agents:
|
||||
if agent.llm is not None:
|
||||
agent.llm.stream = True
|
||||
|
||||
|
||||
def setup_agents(
|
||||
crew: Crew,
|
||||
agents: Iterable[BaseAgent],
|
||||
embedder: EmbedderConfig | None,
|
||||
function_calling_llm: Any,
|
||||
step_callback: Callable[..., Any] | None,
|
||||
) -> None:
|
||||
"""Set up agents for crew execution.
|
||||
|
||||
Args:
|
||||
crew: The crew instance agents belong to.
|
||||
agents: Iterable of agents to set up.
|
||||
embedder: Embedder configuration for knowledge.
|
||||
function_calling_llm: Default function calling LLM for agents.
|
||||
step_callback: Default step callback for agents.
|
||||
"""
|
||||
for agent in agents:
|
||||
agent.crew = crew
|
||||
agent.set_knowledge(crew_embedder=embedder)
|
||||
if not agent.function_calling_llm: # type: ignore[attr-defined]
|
||||
agent.function_calling_llm = function_calling_llm # type: ignore[attr-defined]
|
||||
if not agent.step_callback: # type: ignore[attr-defined]
|
||||
agent.step_callback = step_callback # type: ignore[attr-defined]
|
||||
agent.create_agent_executor()
|
||||
|
||||
|
||||
class TaskExecutionData:
|
||||
"""Data container for prepared task execution information."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
tools: list[Any],
|
||||
should_skip: bool = False,
|
||||
) -> None:
|
||||
"""Initialize task execution data.
|
||||
|
||||
Args:
|
||||
agent: The agent to use for task execution (None if skipped).
|
||||
tools: Prepared tools for the task.
|
||||
should_skip: Whether the task should be skipped (replay).
|
||||
"""
|
||||
self.agent = agent
|
||||
self.tools = tools
|
||||
self.should_skip = should_skip
|
||||
|
||||
|
||||
def prepare_task_execution(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_index: int,
|
||||
start_index: int | None,
|
||||
task_outputs: list[Any],
|
||||
last_sync_output: Any | None,
|
||||
) -> tuple[TaskExecutionData, list[Any], Any | None]:
|
||||
"""Prepare a task for execution, handling replay skip logic and agent/tool setup.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The task to prepare.
|
||||
task_index: Index of the current task.
|
||||
start_index: Index to start execution from (for replay).
|
||||
task_outputs: Current list of task outputs.
|
||||
last_sync_output: Last synchronous task output.
|
||||
|
||||
Returns:
|
||||
A tuple of (TaskExecutionData or None if skipped, updated task_outputs, updated last_sync_output).
|
||||
If the task should be skipped, TaskExecutionData will have should_skip=True.
|
||||
|
||||
Raises:
|
||||
ValueError: If no agent is available for the task.
|
||||
"""
|
||||
# Handle replay skip
|
||||
if start_index is not None and task_index < start_index:
|
||||
if task.output:
|
||||
if task.async_execution:
|
||||
task_outputs.append(task.output)
|
||||
else:
|
||||
task_outputs = [task.output]
|
||||
last_sync_output = task.output
|
||||
return (
|
||||
TaskExecutionData(agent=None, tools=[], should_skip=True),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
agent_to_use = crew._get_agent_to_use(task)
|
||||
if agent_to_use is None:
|
||||
raise ValueError(
|
||||
f"No agent available for task: {task.description}. "
|
||||
f"Ensure that either the task has an assigned agent "
|
||||
f"or a manager agent is provided."
|
||||
)
|
||||
|
||||
tools_for_task = task.tools or agent_to_use.tools or []
|
||||
tools_for_task = crew._prepare_tools(
|
||||
agent_to_use,
|
||||
task,
|
||||
tools_for_task,
|
||||
)
|
||||
|
||||
crew._log_task_start(task, agent_to_use.role)
|
||||
|
||||
return (
|
||||
TaskExecutionData(agent=agent_to_use, tools=tools_for_task),
|
||||
task_outputs,
|
||||
last_sync_output,
|
||||
)
|
||||
|
||||
|
||||
def check_conditional_skip(
|
||||
crew: Crew,
|
||||
task: Any,
|
||||
task_outputs: list[Any],
|
||||
task_index: int,
|
||||
was_replayed: bool,
|
||||
) -> Any | None:
|
||||
"""Check if a conditional task should be skipped.
|
||||
|
||||
Args:
|
||||
crew: The crew instance.
|
||||
task: The conditional task to check.
|
||||
task_outputs: List of previous task outputs.
|
||||
task_index: Index of the current task.
|
||||
was_replayed: Whether this is a replayed execution.
|
||||
|
||||
Returns:
|
||||
The skipped task output if the task should be skipped, None otherwise.
|
||||
"""
|
||||
previous_output = task_outputs[-1] if task_outputs else None
|
||||
if previous_output is not None and not task.should_execute(previous_output):
|
||||
crew._logger.log(
|
||||
"debug",
|
||||
f"Skipping conditional task: {task.description}",
|
||||
color="yellow",
|
||||
)
|
||||
skipped_task_output = task.get_skipped_task_output()
|
||||
|
||||
if not was_replayed:
|
||||
crew._store_execution_log(task, skipped_task_output, task_index)
|
||||
return skipped_task_output
|
||||
return None
|
||||
|
||||
|
||||
def prepare_kickoff(crew: Crew, inputs: dict[str, Any] | None) -> dict[str, Any] | None:
|
||||
"""Prepare crew for kickoff execution.
|
||||
|
||||
Handles before callbacks, event emission, task handler reset, input
|
||||
interpolation, task callbacks, agent setup, and planning.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to prepare.
|
||||
inputs: Optional input dictionary to pass to the crew.
|
||||
|
||||
Returns:
|
||||
The potentially modified inputs dictionary after before callbacks.
|
||||
"""
|
||||
from crewai.events.event_bus import crewai_event_bus
|
||||
from crewai.events.types.crew_events import CrewKickoffStartedEvent
|
||||
|
||||
for before_callback in crew.before_kickoff_callbacks:
|
||||
if inputs is None:
|
||||
inputs = {}
|
||||
inputs = before_callback(inputs)
|
||||
|
||||
future = crewai_event_bus.emit(
|
||||
crew,
|
||||
CrewKickoffStartedEvent(crew_name=crew.name, inputs=inputs),
|
||||
)
|
||||
if future is not None:
|
||||
try:
|
||||
future.result()
|
||||
except Exception: # noqa: S110
|
||||
pass
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
crew._logging_color = "bold_purple"
|
||||
|
||||
if inputs is not None:
|
||||
crew._inputs = inputs
|
||||
crew._interpolate_inputs(inputs)
|
||||
crew._set_tasks_callbacks()
|
||||
crew._set_allow_crewai_trigger_context_for_first_task()
|
||||
|
||||
setup_agents(
|
||||
crew,
|
||||
crew.agents,
|
||||
crew.embedder,
|
||||
crew.function_calling_llm,
|
||||
crew.step_callback,
|
||||
)
|
||||
|
||||
if crew.planning:
|
||||
crew._handle_crew_planning()
|
||||
|
||||
return inputs
|
||||
|
||||
|
||||
class StreamingContext:
|
||||
"""Container for streaming state and holders used during crew execution."""
|
||||
|
||||
def __init__(self, use_async: bool = False) -> None:
|
||||
"""Initialize streaming context.
|
||||
|
||||
Args:
|
||||
use_async: Whether to use async streaming mode.
|
||||
"""
|
||||
self.result_holder: list[CrewOutput] = []
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=use_async
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
class ForEachStreamingContext:
|
||||
"""Container for streaming state used in for_each crew execution methods."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize for_each streaming context."""
|
||||
self.result_holder: list[list[CrewOutput]] = [[]]
|
||||
self.current_task_info: TaskInfo = {
|
||||
"index": 0,
|
||||
"name": "",
|
||||
"id": "",
|
||||
"agent_role": "",
|
||||
"agent_id": "",
|
||||
}
|
||||
self.state: StreamingState = create_streaming_state(
|
||||
self.current_task_info, self.result_holder, use_async=True
|
||||
)
|
||||
self.output_holder: list[CrewStreamingOutput | FlowStreamingOutput] = []
|
||||
|
||||
|
||||
async def run_for_each_async(
|
||||
crew: Crew,
|
||||
inputs: list[dict[str, Any]],
|
||||
kickoff_fn: Callable[
|
||||
[Crew, dict[str, Any]], Coroutine[Any, Any, CrewOutput | CrewStreamingOutput]
|
||||
],
|
||||
) -> list[CrewOutput | CrewStreamingOutput] | CrewStreamingOutput:
|
||||
"""Execute crew workflow for each input asynchronously.
|
||||
|
||||
Args:
|
||||
crew: The crew instance to execute.
|
||||
inputs: List of input dictionaries for each execution.
|
||||
kickoff_fn: Async function to call for each crew copy (kickoff_async or akickoff).
|
||||
|
||||
Returns:
|
||||
If streaming, a single CrewStreamingOutput that yields chunks from all crews.
|
||||
Otherwise, a list of CrewOutput results.
|
||||
"""
|
||||
from crewai.types.usage_metrics import UsageMetrics
|
||||
from crewai.utilities.streaming import (
|
||||
create_async_chunk_generator,
|
||||
signal_end,
|
||||
signal_error,
|
||||
)
|
||||
|
||||
crew_copies = [crew.copy() for _ in inputs]
|
||||
|
||||
if crew.stream:
|
||||
ctx = ForEachStreamingContext()
|
||||
|
||||
async def run_all_crews() -> None:
|
||||
try:
|
||||
streaming_outputs: list[CrewStreamingOutput] = []
|
||||
for i, crew_copy in enumerate(crew_copies):
|
||||
streaming = await kickoff_fn(crew_copy, inputs[i])
|
||||
if isinstance(streaming, CrewStreamingOutput):
|
||||
streaming_outputs.append(streaming)
|
||||
|
||||
async def consume_stream(
|
||||
stream_output: CrewStreamingOutput,
|
||||
) -> CrewOutput:
|
||||
async for chunk in stream_output:
|
||||
if (
|
||||
ctx.state.async_queue is not None
|
||||
and ctx.state.loop is not None
|
||||
):
|
||||
ctx.state.loop.call_soon_threadsafe(
|
||||
ctx.state.async_queue.put_nowait, chunk
|
||||
)
|
||||
return stream_output.result
|
||||
|
||||
crew_results = await asyncio.gather(
|
||||
*[consume_stream(s) for s in streaming_outputs]
|
||||
)
|
||||
ctx.result_holder[0] = list(crew_results)
|
||||
except Exception as e:
|
||||
signal_error(ctx.state, e, is_async=True)
|
||||
finally:
|
||||
signal_end(ctx.state, is_async=True)
|
||||
|
||||
streaming_output = CrewStreamingOutput(
|
||||
async_iterator=create_async_chunk_generator(
|
||||
ctx.state, run_all_crews, ctx.output_holder
|
||||
)
|
||||
)
|
||||
|
||||
def set_results_wrapper(result: Any) -> None:
|
||||
streaming_output._set_results(result)
|
||||
|
||||
streaming_output._set_result = set_results_wrapper # type: ignore[method-assign]
|
||||
ctx.output_holder.append(streaming_output)
|
||||
|
||||
return streaming_output
|
||||
|
||||
async_tasks: list[asyncio.Task[CrewOutput | CrewStreamingOutput]] = [
|
||||
asyncio.create_task(kickoff_fn(crew_copy, input_data))
|
||||
for crew_copy, input_data in zip(crew_copies, inputs, strict=True)
|
||||
]
|
||||
|
||||
results = await asyncio.gather(*async_tasks)
|
||||
|
||||
total_usage_metrics = UsageMetrics()
|
||||
for crew_copy in crew_copies:
|
||||
if crew_copy.usage_metrics:
|
||||
total_usage_metrics.add_usage_metrics(crew_copy.usage_metrics)
|
||||
crew.usage_metrics = total_usage_metrics
|
||||
|
||||
crew._task_output_handler.reset()
|
||||
return list(results)
|
||||
@@ -140,7 +140,9 @@ class EventListener(BaseEventListener):
|
||||
def on_crew_started(source: Any, event: CrewKickoffStartedEvent) -> None:
|
||||
with self._crew_tree_lock:
|
||||
self.formatter.create_crew_tree(event.crew_name or "Crew", source.id)
|
||||
self._telemetry.crew_execution_span(source, event.inputs)
|
||||
source._execution_span = self._telemetry.crew_execution_span(
|
||||
source, event.inputs
|
||||
)
|
||||
self._crew_tree_lock.notify_all()
|
||||
|
||||
@crewai_event_bus.on(CrewKickoffCompletedEvent)
|
||||
|
||||
@@ -1032,6 +1032,20 @@ class Flow(Generic[T], metaclass=FlowMeta):
|
||||
finally:
|
||||
detach(flow_token)
|
||||
|
||||
async def akickoff(
|
||||
self, inputs: dict[str, Any] | None = None
|
||||
) -> Any | FlowStreamingOutput:
|
||||
"""Native async method to start the flow execution. Alias for kickoff_async.
|
||||
|
||||
|
||||
Args:
|
||||
inputs: Optional dictionary containing input values and/or a state ID for restoration.
|
||||
|
||||
Returns:
|
||||
The final output from the flow, which is the result of the last executed method.
|
||||
"""
|
||||
return await self.kickoff_async(inputs)
|
||||
|
||||
async def _execute_start_method(self, start_method_name: FlowMethodName) -> None:
|
||||
"""Executes a flow's start method and its triggered listeners.
|
||||
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from typing import TYPE_CHECKING
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from crewai.events.event_listener import event_listener
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
@@ -9,17 +9,22 @@ from crewai.utilities.printer import Printer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.agents.crew_agent_executor import CrewAgentExecutor
|
||||
from crewai.lite_agent import LiteAgent
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
from crewai.utilities.types import LLMMessage
|
||||
|
||||
|
||||
class LLMCallHookContext:
|
||||
"""Context object passed to LLM call hooks with full executor access.
|
||||
"""Context object passed to LLM call hooks.
|
||||
|
||||
Provides hooks with complete access to the executor state, allowing
|
||||
Provides hooks with complete access to the execution state, allowing
|
||||
modification of messages, responses, and executor attributes.
|
||||
|
||||
Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
|
||||
|
||||
Attributes:
|
||||
executor: Full reference to the CrewAgentExecutor instance
|
||||
messages: Direct reference to executor.messages (mutable list).
|
||||
executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
|
||||
messages: Direct reference to messages (mutable list).
|
||||
Can be modified in both before_llm_call and after_llm_call hooks.
|
||||
Modifications in after_llm_call hooks persist to the next iteration,
|
||||
allowing hooks to modify conversation history for subsequent LLM calls.
|
||||
@@ -27,33 +32,75 @@ class LLMCallHookContext:
|
||||
Do NOT replace the list (e.g., context.messages = []), as this will break
|
||||
the executor. Use context.messages.append() or context.messages.extend()
|
||||
instead of assignment.
|
||||
agent: Reference to the agent executing the task
|
||||
task: Reference to the task being executed
|
||||
crew: Reference to the crew instance
|
||||
agent: Reference to the agent executing the task (None for direct LLM calls)
|
||||
task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
|
||||
crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
|
||||
llm: Reference to the LLM instance
|
||||
iterations: Current iteration count
|
||||
iterations: Current iteration count (0 for direct LLM calls)
|
||||
response: LLM response string (only set for after_llm_call hooks).
|
||||
Can be modified by returning a new string from after_llm_call hook.
|
||||
"""
|
||||
|
||||
executor: CrewAgentExecutor | LiteAgent | None
|
||||
messages: list[LLMMessage]
|
||||
agent: Any
|
||||
task: Any
|
||||
crew: Any
|
||||
llm: BaseLLM | None | str | Any
|
||||
iterations: int
|
||||
response: str | None
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
executor: CrewAgentExecutor,
|
||||
executor: CrewAgentExecutor | LiteAgent | None = None,
|
||||
response: str | None = None,
|
||||
messages: list[LLMMessage] | None = None,
|
||||
llm: BaseLLM | str | Any | None = None, # TODO: look into
|
||||
agent: Any | None = None,
|
||||
task: Any | None = None,
|
||||
crew: Any | None = None,
|
||||
) -> None:
|
||||
"""Initialize hook context with executor reference.
|
||||
"""Initialize hook context with executor reference or direct parameters.
|
||||
|
||||
Args:
|
||||
executor: The CrewAgentExecutor instance
|
||||
executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
|
||||
response: Optional response string (for after_llm_call hooks)
|
||||
messages: Optional messages list (for direct LLM calls when executor is None)
|
||||
llm: Optional LLM instance (for direct LLM calls when executor is None)
|
||||
agent: Optional agent reference (for direct LLM calls when executor is None)
|
||||
task: Optional task reference (for direct LLM calls when executor is None)
|
||||
crew: Optional crew reference (for direct LLM calls when executor is None)
|
||||
"""
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.agent = executor.agent
|
||||
self.task = executor.task
|
||||
self.crew = executor.crew
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
if executor is not None:
|
||||
# Existing path: extract from executor
|
||||
self.executor = executor
|
||||
self.messages = executor.messages
|
||||
self.llm = executor.llm
|
||||
self.iterations = executor.iterations
|
||||
# Handle CrewAgentExecutor vs LiteAgent differences
|
||||
if hasattr(executor, "agent"):
|
||||
self.agent = executor.agent
|
||||
self.task = cast("CrewAgentExecutor", executor).task
|
||||
self.crew = cast("CrewAgentExecutor", executor).crew
|
||||
else:
|
||||
# LiteAgent case - is the agent itself, doesn't have task/crew
|
||||
self.agent = (
|
||||
executor.original_agent
|
||||
if hasattr(executor, "original_agent")
|
||||
else executor
|
||||
)
|
||||
self.task = None
|
||||
self.crew = None
|
||||
else:
|
||||
# New path: direct LLM call with explicit parameters
|
||||
self.executor = None
|
||||
self.messages = messages or []
|
||||
self.llm = llm
|
||||
self.agent = agent
|
||||
self.task = task
|
||||
self.crew = crew
|
||||
self.iterations = 0
|
||||
|
||||
self.response = response
|
||||
|
||||
def request_human_input(
|
||||
|
||||
@@ -32,8 +32,8 @@ class Knowledge(BaseModel):
|
||||
sources: list[BaseKnowledgeSource],
|
||||
embedder: EmbedderConfig | None = None,
|
||||
storage: KnowledgeStorage | None = None,
|
||||
**data,
|
||||
):
|
||||
**data: object,
|
||||
) -> None:
|
||||
super().__init__(**data)
|
||||
if storage:
|
||||
self.storage = storage
|
||||
@@ -75,3 +75,44 @@ class Knowledge(BaseModel):
|
||||
self.storage.reset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
async def aquery(
|
||||
self, query: list[str], results_limit: int = 5, score_threshold: float = 0.6
|
||||
) -> list[SearchResult]:
|
||||
"""Query across all knowledge sources asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
results_limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
The top results matching the query.
|
||||
|
||||
Raises:
|
||||
ValueError: If storage is not initialized.
|
||||
"""
|
||||
if self.storage is None:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
return await self.storage.asearch(
|
||||
query,
|
||||
limit=results_limit,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
|
||||
async def aadd_sources(self) -> None:
|
||||
"""Add all knowledge sources to storage asynchronously."""
|
||||
try:
|
||||
for source in self.sources:
|
||||
source.storage = self.storage
|
||||
await source.aadd()
|
||||
except Exception as e:
|
||||
raise e
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.areset()
|
||||
else:
|
||||
raise ValueError("Storage is not initialized.")
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -25,7 +26,10 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@classmethod
|
||||
def validate_file_path(
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -38,7 +42,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
raise ValueError("Either file_path or file_paths must be provided")
|
||||
return v
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
"""Post-initialization method to load content."""
|
||||
self.safe_file_paths = self._process_file_paths()
|
||||
self.validate_content()
|
||||
@@ -48,7 +52,7 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
def load_content(self) -> dict[Path, str]:
|
||||
"""Load and preprocess file content. Should be overridden by subclasses. Assume that the file path is relative to the project root in the knowledge directory."""
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -65,13 +69,20 @@ class BaseFileKnowledgeSource(BaseKnowledgeSource, ABC):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def _save_documents(self):
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage."""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously."""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
def convert_to_path(self, path: Path | str) -> Path:
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
@@ -39,12 +39,32 @@ class BaseKnowledgeSource(BaseModel, ABC):
|
||||
for i in range(0, len(text), self.chunk_size - self.chunk_overlap)
|
||||
]
|
||||
|
||||
def _save_documents(self):
|
||||
"""
|
||||
Save the documents to the storage.
|
||||
def _save_documents(self) -> None:
|
||||
"""Save the documents to the storage.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
self.storage.save(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@abstractmethod
|
||||
async def aadd(self) -> None:
|
||||
"""Process content, chunk it, compute embeddings, and save them asynchronously."""
|
||||
|
||||
async def _asave_documents(self) -> None:
|
||||
"""Save the documents to the storage asynchronously.
|
||||
|
||||
This method should be called after the chunks and embeddings are generated.
|
||||
|
||||
Raises:
|
||||
ValueError: If no storage is configured.
|
||||
"""
|
||||
if self.storage:
|
||||
await self.storage.asave(self.chunks)
|
||||
else:
|
||||
raise ValueError("No storage found to save documents.")
|
||||
|
||||
@@ -2,27 +2,24 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Iterator
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from urllib.parse import urlparse
|
||||
|
||||
|
||||
try:
|
||||
from docling.datamodel.base_models import ( # type: ignore[import-not-found]
|
||||
InputFormat,
|
||||
)
|
||||
from docling.document_converter import ( # type: ignore[import-not-found]
|
||||
DocumentConverter,
|
||||
)
|
||||
from docling.exceptions import ConversionError # type: ignore[import-not-found]
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import ( # type: ignore[import-not-found]
|
||||
HierarchicalChunker,
|
||||
)
|
||||
from docling_core.types.doc.document import ( # type: ignore[import-not-found]
|
||||
DoclingDocument,
|
||||
)
|
||||
from docling.datamodel.base_models import InputFormat
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling.exceptions import ConversionError
|
||||
from docling_core.transforms.chunker.hierarchical_chunker import HierarchicalChunker
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
DOCLING_AVAILABLE = True
|
||||
except ImportError:
|
||||
DOCLING_AVAILABLE = False
|
||||
# Provide type stubs for when docling is not available
|
||||
if TYPE_CHECKING:
|
||||
from docling.document_converter import DocumentConverter
|
||||
from docling_core.types.doc.document import DoclingDocument
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
@@ -32,11 +29,13 @@ from crewai.utilities.logger import Logger
|
||||
|
||||
|
||||
class CrewDoclingSource(BaseKnowledgeSource):
|
||||
"""Default Source class for converting documents to markdown or json
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without any additional dependencies and follows the docling package as the source of truth.
|
||||
"""Default Source class for converting documents to markdown or json.
|
||||
|
||||
This will auto support PDF, DOCX, and TXT, XLSX, Images, and HTML files without
|
||||
any additional dependencies and follows the docling package as the source of truth.
|
||||
"""
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
def __init__(self, *args: Any, **kwargs: Any) -> None:
|
||||
if not DOCLING_AVAILABLE:
|
||||
raise ImportError(
|
||||
"The docling package is required to use CrewDoclingSource. "
|
||||
@@ -66,7 +65,7 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
)
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -99,6 +98,15 @@ class CrewDoclingSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add docling content asynchronously."""
|
||||
if self.content is None:
|
||||
return
|
||||
for doc in self.content:
|
||||
new_chunks_iterable = self._chunk_doc(doc)
|
||||
self.chunks.extend(list(new_chunks_iterable))
|
||||
await self._asave_documents()
|
||||
|
||||
def _convert_source_to_docling_documents(self) -> list[DoclingDocument]:
|
||||
conv_results_iter = self.document_converter.convert_all(self.safe_file_paths)
|
||||
return [result.document for result in conv_results_iter]
|
||||
|
||||
@@ -31,6 +31,15 @@ class CSVKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add CSV file content asynchronously."""
|
||||
content_str = (
|
||||
str(self.content) if isinstance(self.content, dict) else self.content
|
||||
)
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field, field_validator
|
||||
|
||||
@@ -26,7 +28,10 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
safe_file_paths: list[Path] = Field(default_factory=list)
|
||||
|
||||
@field_validator("file_path", "file_paths", mode="before")
|
||||
def validate_file_path(cls, v, info): # noqa: N805
|
||||
@classmethod
|
||||
def validate_file_path(
|
||||
cls, v: Path | list[Path] | str | list[str] | None, info: Any
|
||||
) -> Path | list[Path] | str | list[str] | None:
|
||||
"""Validate that at least one of file_path or file_paths is provided."""
|
||||
# Single check if both are None, O(1) instead of nested conditions
|
||||
if (
|
||||
@@ -69,7 +74,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
|
||||
return [self.convert_to_path(path) for path in path_list]
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate the paths."""
|
||||
for path in self.safe_file_paths:
|
||||
if not path.exists():
|
||||
@@ -86,7 +91,7 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
color="red",
|
||||
)
|
||||
|
||||
def model_post_init(self, _) -> None:
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
if self.file_path:
|
||||
self._logger.log(
|
||||
"warning",
|
||||
@@ -128,12 +133,12 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
"""Convert a path to a Path object."""
|
||||
return Path(KNOWLEDGE_DIRECTORY + "/" + path) if isinstance(path, str) else path
|
||||
|
||||
def _import_dependencies(self):
|
||||
def _import_dependencies(self) -> ModuleType:
|
||||
"""Dynamically import dependencies."""
|
||||
try:
|
||||
import pandas as pd # type: ignore[import-untyped,import-not-found]
|
||||
import pandas as pd # type: ignore[import-untyped]
|
||||
|
||||
return pd
|
||||
return pd # type: ignore[no-any-return]
|
||||
except ImportError as e:
|
||||
missing_package = str(e).split()[-1]
|
||||
raise ImportError(
|
||||
@@ -159,6 +164,20 @@ class ExcelKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add Excel file content asynchronously."""
|
||||
content_str = ""
|
||||
for value in self.content.values():
|
||||
if isinstance(value, dict):
|
||||
for sheet_value in value.values():
|
||||
content_str += str(sheet_value) + "\n"
|
||||
else:
|
||||
content_str += str(value) + "\n"
|
||||
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -44,6 +44,15 @@ class JSONKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add JSON file content asynchronously."""
|
||||
content_str = (
|
||||
str(self.content) if isinstance(self.content, dict) else self.content
|
||||
)
|
||||
new_chunks = self._chunk_text(content_str)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
from pathlib import Path
|
||||
from types import ModuleType
|
||||
|
||||
from crewai.knowledge.source.base_file_knowledge_source import BaseFileKnowledgeSource
|
||||
|
||||
@@ -23,7 +24,7 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
content[path] = text
|
||||
return content
|
||||
|
||||
def _import_pdfplumber(self):
|
||||
def _import_pdfplumber(self) -> ModuleType:
|
||||
"""Dynamically import pdfplumber."""
|
||||
try:
|
||||
import pdfplumber
|
||||
@@ -44,6 +45,13 @@ class PDFKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add PDF file content asynchronously."""
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Any
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from crewai.knowledge.source.base_knowledge_source import BaseKnowledgeSource
|
||||
@@ -9,11 +11,11 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
content: str = Field(...)
|
||||
collection_name: str | None = Field(default=None)
|
||||
|
||||
def model_post_init(self, _):
|
||||
def model_post_init(self, _: Any) -> None:
|
||||
"""Post-initialization method to validate content."""
|
||||
self.validate_content()
|
||||
|
||||
def validate_content(self):
|
||||
def validate_content(self) -> None:
|
||||
"""Validate string content."""
|
||||
if not isinstance(self.content, str):
|
||||
raise ValueError("StringKnowledgeSource only accepts string content")
|
||||
@@ -24,6 +26,12 @@ class StringKnowledgeSource(BaseKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add string content asynchronously."""
|
||||
new_chunks = self._chunk_text(self.content)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -25,6 +25,13 @@ class TextFileKnowledgeSource(BaseFileKnowledgeSource):
|
||||
self.chunks.extend(new_chunks)
|
||||
self._save_documents()
|
||||
|
||||
async def aadd(self) -> None:
|
||||
"""Add text file content asynchronously."""
|
||||
for text in self.content.values():
|
||||
new_chunks = self._chunk_text(text)
|
||||
self.chunks.extend(new_chunks)
|
||||
await self._asave_documents()
|
||||
|
||||
def _chunk_text(self, text: str) -> list[str]:
|
||||
"""Utility method to split text into chunks."""
|
||||
return [
|
||||
|
||||
@@ -21,10 +21,28 @@ class BaseKnowledgeStorage(ABC):
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def save(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base asynchronously."""
|
||||
|
||||
@abstractmethod
|
||||
def reset(self) -> None:
|
||||
"""Reset the knowledge base."""
|
||||
|
||||
@abstractmethod
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
|
||||
@@ -25,8 +25,8 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
def __init__(
|
||||
self,
|
||||
embedder: ProviderSpec
|
||||
| BaseEmbeddingsProvider
|
||||
| type[BaseEmbeddingsProvider]
|
||||
| BaseEmbeddingsProvider[Any]
|
||||
| type[BaseEmbeddingsProvider[Any]]
|
||||
| None = None,
|
||||
collection_name: str | None = None,
|
||||
) -> None:
|
||||
@@ -127,3 +127,96 @@ class KnowledgeStorage(BaseKnowledgeStorage):
|
||||
) from e
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: list[str],
|
||||
limit: int = 5,
|
||||
metadata_filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[SearchResult]:
|
||||
"""Search for documents in the knowledge base asynchronously.
|
||||
|
||||
Args:
|
||||
query: List of query strings.
|
||||
limit: Maximum number of results to return.
|
||||
metadata_filter: Optional metadata filter for the search.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of search results.
|
||||
"""
|
||||
try:
|
||||
if not query:
|
||||
raise ValueError("Query cannot be empty")
|
||||
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
query_text = " ".join(query) if len(query) > 1 else query[0]
|
||||
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query_text,
|
||||
limit=limit,
|
||||
metadata_filter=metadata_filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during knowledge search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
async def asave(self, documents: list[str]) -> None:
|
||||
"""Save documents to the knowledge base asynchronously.
|
||||
|
||||
Args:
|
||||
documents: List of document strings to save.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
rag_documents: list[BaseRecord] = [{"content": doc} for doc in documents]
|
||||
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=rag_documents
|
||||
)
|
||||
except Exception as e:
|
||||
if "dimension mismatch" in str(e).lower():
|
||||
Logger(verbose=True).log(
|
||||
"error",
|
||||
"Embedding dimension mismatch. This usually happens when mixing different embedding models. Try resetting the collection using `crewai reset-memories -a`",
|
||||
"red",
|
||||
)
|
||||
raise ValueError(
|
||||
"Embedding dimension mismatch. Make sure you're using the same embedding model "
|
||||
"across all operations with this collection."
|
||||
"Try resetting the collection using `crewai reset-memories -a`"
|
||||
) from e
|
||||
Logger(verbose=True).log("error", f"Failed to upsert documents: {e}", "red")
|
||||
raise
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the knowledge base asynchronously."""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"knowledge_{self.collection_name}"
|
||||
if self.collection_name
|
||||
else "knowledge"
|
||||
)
|
||||
await client.adelete_collection(collection_name=collection_name)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during knowledge reset: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
@@ -38,6 +38,8 @@ from crewai.events.types.agent_events import (
|
||||
)
|
||||
from crewai.events.types.logging_events import AgentLogsExecutionEvent
|
||||
from crewai.flow.flow_trackable import FlowTrackable
|
||||
from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
|
||||
from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
|
||||
from crewai.lite_agent_output import LiteAgentOutput
|
||||
from crewai.llm import LLM
|
||||
from crewai.llms.base_llm import BaseLLM
|
||||
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
_guardrail: GuardrailCallable | None = PrivateAttr(default=None)
|
||||
_guardrail_retry_count: int = PrivateAttr(default=0)
|
||||
_callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
|
||||
_before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
|
||||
default_factory=get_before_llm_call_hooks
|
||||
)
|
||||
_after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
|
||||
default_factory=get_after_llm_call_hooks
|
||||
)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def setup_llm(self) -> Self:
|
||||
@@ -246,6 +254,26 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
"""Return the original role for compatibility with tool interfaces."""
|
||||
return self.role
|
||||
|
||||
@property
|
||||
def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
|
||||
"""Get the before_llm_call hooks for this agent."""
|
||||
return self._before_llm_call_hooks
|
||||
|
||||
@property
|
||||
def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
|
||||
"""Get the after_llm_call hooks for this agent."""
|
||||
return self._after_llm_call_hooks
|
||||
|
||||
@property
|
||||
def messages(self) -> list[LLMMessage]:
|
||||
"""Get the messages list for hook context compatibility."""
|
||||
return self._messages
|
||||
|
||||
@property
|
||||
def iterations(self) -> int:
|
||||
"""Get the current iteration count for hook context compatibility."""
|
||||
return self._iterations
|
||||
|
||||
def kickoff(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
@@ -504,7 +532,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
AgentFinish: The final result of the agent execution.
|
||||
"""
|
||||
# Execute the agent loop
|
||||
formatted_answer = None
|
||||
formatted_answer: AgentAction | AgentFinish | None = None
|
||||
while not isinstance(formatted_answer, AgentFinish):
|
||||
try:
|
||||
if has_reached_max_iterations(self._iterations, self.max_iterations):
|
||||
@@ -526,6 +554,7 @@ class LiteAgent(FlowTrackable, BaseModel):
|
||||
callbacks=self._callbacks,
|
||||
printer=self._printer,
|
||||
from_agent=self,
|
||||
executor_context=self,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -67,6 +67,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from crewai.agent.core import Agent
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.llms.providers.anthropic.completion import AnthropicThinkingConfig
|
||||
from crewai.task import Task
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
from crewai.utilities.types import LLMMessage
|
||||
@@ -585,6 +586,7 @@ class LLM(BaseLLM):
|
||||
reasoning_effort: Literal["none", "low", "medium", "high"] | None = None,
|
||||
stream: bool = False,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize LLM instance.
|
||||
@@ -1642,6 +1644,10 @@ class LLM(BaseLLM):
|
||||
if message.get("role") == "system":
|
||||
msg_role: Literal["assistant"] = "assistant"
|
||||
message["role"] = msg_role
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# --- 5) Set up callbacks if provided
|
||||
with suppress_warnings():
|
||||
if callbacks and len(callbacks) > 0:
|
||||
@@ -1651,7 +1657,16 @@ class LLM(BaseLLM):
|
||||
params = self._prepare_completion_params(messages, tools)
|
||||
# --- 7) Make the completion call and handle response
|
||||
if self.stream:
|
||||
return self._handle_streaming_response(
|
||||
result = self._handle_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
else:
|
||||
result = self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
@@ -1660,14 +1675,12 @@ class LLM(BaseLLM):
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return self._handle_non_streaming_response(
|
||||
params=params,
|
||||
callbacks=callbacks,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
if isinstance(result, str):
|
||||
result = self._invoke_after_llm_call_hooks(
|
||||
messages, result, from_agent
|
||||
)
|
||||
|
||||
return result
|
||||
except LLMContextLengthExceededError:
|
||||
# Re-raise LLMContextLengthExceededError as it should be handled
|
||||
# by the CrewAgentExecutor._invoke_loop method, which can then decide
|
||||
|
||||
@@ -314,7 +314,7 @@ class BaseLLM(ABC):
|
||||
call_type: LLMCallType,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | None = None,
|
||||
messages: str | list[dict[str, Any]] | None = None,
|
||||
messages: str | list[LLMMessage] | None = None,
|
||||
) -> None:
|
||||
"""Emit LLM call completed event."""
|
||||
crewai_event_bus.emit(
|
||||
@@ -586,3 +586,134 @@ class BaseLLM(ABC):
|
||||
Dictionary with token usage totals
|
||||
"""
|
||||
return UsageMetrics(**self._token_usage)
|
||||
|
||||
def _invoke_before_llm_call_hooks(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
from_agent: Agent | None = None,
|
||||
) -> bool:
|
||||
"""Invoke before_llm_call hooks for direct LLM calls (no agent context).
|
||||
|
||||
This method should be called by native provider implementations before
|
||||
making the actual LLM call when from_agent is None (direct calls).
|
||||
|
||||
Args:
|
||||
messages: The messages being sent to the LLM
|
||||
from_agent: The agent making the call (None for direct calls)
|
||||
|
||||
Returns:
|
||||
True if LLM call should proceed, False if blocked by hook
|
||||
|
||||
Example:
|
||||
>>> # In a native provider's call() method:
|
||||
>>> if from_agent is None and not self._invoke_before_llm_call_hooks(
|
||||
... messages, from_agent
|
||||
... ):
|
||||
... raise ValueError("LLM call blocked by hook")
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None:
|
||||
return True
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_before_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
before_hooks = get_before_llm_call_hooks()
|
||||
if not before_hooks:
|
||||
return True
|
||||
|
||||
hook_context = LLMCallHookContext(
|
||||
executor=None,
|
||||
messages=messages,
|
||||
llm=self,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
)
|
||||
printer = Printer()
|
||||
|
||||
try:
|
||||
for hook in before_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
printer.print(
|
||||
content="LLM call blocked by before_llm_call hook",
|
||||
color="yellow",
|
||||
)
|
||||
return False
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in before_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
return True
|
||||
|
||||
def _invoke_after_llm_call_hooks(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
response: str,
|
||||
from_agent: Agent | None = None,
|
||||
) -> str:
|
||||
"""Invoke after_llm_call hooks for direct LLM calls (no agent context).
|
||||
|
||||
This method should be called by native provider implementations after
|
||||
receiving the LLM response when from_agent is None (direct calls).
|
||||
|
||||
Args:
|
||||
messages: The messages that were sent to the LLM
|
||||
response: The response from the LLM
|
||||
from_agent: The agent that made the call (None for direct calls)
|
||||
|
||||
Returns:
|
||||
The potentially modified response string
|
||||
|
||||
Example:
|
||||
>>> # In a native provider's call() method:
|
||||
>>> if from_agent is None and isinstance(result, str):
|
||||
... result = self._invoke_after_llm_call_hooks(
|
||||
... messages, result, from_agent
|
||||
... )
|
||||
"""
|
||||
# Only invoke hooks for direct calls (no agent context)
|
||||
if from_agent is not None or not isinstance(response, str):
|
||||
return response
|
||||
|
||||
from crewai.hooks.llm_hooks import (
|
||||
LLMCallHookContext,
|
||||
get_after_llm_call_hooks,
|
||||
)
|
||||
from crewai.utilities.printer import Printer
|
||||
|
||||
after_hooks = get_after_llm_call_hooks()
|
||||
if not after_hooks:
|
||||
return response
|
||||
|
||||
hook_context = LLMCallHookContext(
|
||||
executor=None,
|
||||
messages=messages,
|
||||
llm=self,
|
||||
agent=None,
|
||||
task=None,
|
||||
crew=None,
|
||||
response=response,
|
||||
)
|
||||
printer = Printer()
|
||||
modified_response = response
|
||||
|
||||
try:
|
||||
for hook in after_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is not None and isinstance(result, str):
|
||||
modified_response = result
|
||||
hook_context.response = modified_response
|
||||
except Exception as e:
|
||||
printer.print(
|
||||
content=f"Error in after_llm_call hook: {e}",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
return modified_response
|
||||
|
||||
@@ -3,8 +3,9 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from anthropic.types import ThinkingBlock
|
||||
from pydantic import BaseModel
|
||||
|
||||
from crewai.events.types.llm_events import LLMCallType
|
||||
@@ -22,8 +23,7 @@ if TYPE_CHECKING:
|
||||
|
||||
try:
|
||||
from anthropic import Anthropic, AsyncAnthropic
|
||||
from anthropic.types import Message
|
||||
from anthropic.types.tool_use_block import ToolUseBlock
|
||||
from anthropic.types import Message, TextBlock, ThinkingBlock, ToolUseBlock
|
||||
import httpx
|
||||
except ImportError:
|
||||
raise ImportError(
|
||||
@@ -31,6 +31,11 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
class AnthropicThinkingConfig(BaseModel):
|
||||
type: Literal["enabled", "disabled"]
|
||||
budget_tokens: int | None = None
|
||||
|
||||
|
||||
class AnthropicCompletion(BaseLLM):
|
||||
"""Anthropic native completion implementation.
|
||||
|
||||
@@ -52,6 +57,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
stream: bool = False,
|
||||
client_params: dict[str, Any] | None = None,
|
||||
interceptor: BaseInterceptor[httpx.Request, httpx.Response] | None = None,
|
||||
thinking: AnthropicThinkingConfig | None = None,
|
||||
**kwargs: Any,
|
||||
):
|
||||
"""Initialize Anthropic chat completion client.
|
||||
@@ -97,6 +103,10 @@ class AnthropicCompletion(BaseLLM):
|
||||
self.top_p = top_p
|
||||
self.stream = stream
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.thinking = thinking
|
||||
self.previous_thinking_blocks: list[ThinkingBlock] = []
|
||||
# Model-specific settings
|
||||
self.is_claude_3 = "claude-3" in model.lower()
|
||||
self.supports_tools = True
|
||||
|
||||
@property
|
||||
@@ -187,6 +197,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, system_message, tools
|
||||
@@ -323,6 +336,12 @@ class AnthropicCompletion(BaseLLM):
|
||||
if tools and self.supports_tools:
|
||||
params["tools"] = self._convert_tools_for_interference(tools)
|
||||
|
||||
if self.thinking:
|
||||
if isinstance(self.thinking, AnthropicThinkingConfig):
|
||||
params["thinking"] = self.thinking.model_dump()
|
||||
else:
|
||||
params["thinking"] = self.thinking
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
@@ -362,6 +381,34 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
return anthropic_tools
|
||||
|
||||
def _extract_thinking_block(
|
||||
self, content_block: Any
|
||||
) -> ThinkingBlock | dict[str, Any] | None:
|
||||
"""Extract and format thinking block from content block.
|
||||
|
||||
Args:
|
||||
content_block: Content block from Anthropic response
|
||||
|
||||
Returns:
|
||||
Dictionary with thinking block data including signature, or None if not a thinking block
|
||||
"""
|
||||
if content_block.type == "thinking":
|
||||
thinking_block = {
|
||||
"type": "thinking",
|
||||
"thinking": content_block.thinking,
|
||||
}
|
||||
if hasattr(content_block, "signature"):
|
||||
thinking_block["signature"] = content_block.signature
|
||||
return thinking_block
|
||||
if content_block.type == "redacted_thinking":
|
||||
redacted_block = {"type": "redacted_thinking"}
|
||||
if hasattr(content_block, "thinking"):
|
||||
redacted_block["thinking"] = content_block.thinking
|
||||
if hasattr(content_block, "signature"):
|
||||
redacted_block["signature"] = content_block.signature
|
||||
return redacted_block
|
||||
return None
|
||||
|
||||
def _format_messages_for_anthropic(
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[LLMMessage], str | None]:
|
||||
@@ -371,6 +418,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
- System messages are separate from conversation messages
|
||||
- Messages must alternate between user and assistant
|
||||
- First message must be from user
|
||||
- When thinking is enabled, assistant messages must start with thinking blocks
|
||||
|
||||
Args:
|
||||
messages: Input messages
|
||||
@@ -395,8 +443,29 @@ class AnthropicCompletion(BaseLLM):
|
||||
system_message = cast(str, content)
|
||||
else:
|
||||
role_str = role if role is not None else "user"
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append({"role": role_str, "content": content_str})
|
||||
|
||||
if isinstance(content, list):
|
||||
formatted_messages.append({"role": role_str, "content": content})
|
||||
elif (
|
||||
role_str == "assistant"
|
||||
and self.thinking
|
||||
and self.previous_thinking_blocks
|
||||
):
|
||||
structured_content = cast(
|
||||
list[dict[str, Any]],
|
||||
[
|
||||
*self.previous_thinking_blocks,
|
||||
{"type": "text", "text": content if content else ""},
|
||||
],
|
||||
)
|
||||
formatted_messages.append(
|
||||
LLMMessage(role=role_str, content=structured_content)
|
||||
)
|
||||
else:
|
||||
content_str = content if content is not None else ""
|
||||
formatted_messages.append(
|
||||
LLMMessage(role=role_str, content=content_str)
|
||||
)
|
||||
|
||||
# Ensure first message is from user (Anthropic requirement)
|
||||
if not formatted_messages:
|
||||
@@ -446,7 +515,6 @@ class AnthropicCompletion(BaseLLM):
|
||||
if tool_uses and tool_uses[0].name == "structured_output":
|
||||
structured_data = tool_uses[0].input
|
||||
structured_json = json.dumps(structured_data)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -474,15 +542,22 @@ class AnthropicCompletion(BaseLLM):
|
||||
from_agent,
|
||||
)
|
||||
|
||||
# Extract text content
|
||||
content = ""
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
|
||||
if response.content:
|
||||
for content_block in response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
content += content_block.text
|
||||
else:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
@@ -494,7 +569,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"Anthropic API usage: {usage}")
|
||||
|
||||
return content
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
@@ -535,6 +612,16 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
final_message: Message = stream.get_final_message()
|
||||
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
if final_message.content:
|
||||
for content_block in final_message.content:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
usage = self._extract_anthropic_token_usage(final_message)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
@@ -588,7 +675,52 @@ class AnthropicCompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], full_response, from_agent
|
||||
)
|
||||
|
||||
def _execute_tools_and_collect_results(
|
||||
self,
|
||||
tool_uses: list[ToolUseBlock],
|
||||
available_functions: dict[str, Any],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Execute tools and collect results in Anthropic format.
|
||||
|
||||
Args:
|
||||
tool_uses: List of tool use blocks from Claude's response
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
List of tool result dictionaries in Anthropic format
|
||||
"""
|
||||
tool_results = []
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=cast(dict[str, Any], function_args),
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
|
||||
return tool_results
|
||||
|
||||
def _handle_tool_use_conversation(
|
||||
self,
|
||||
@@ -607,37 +739,33 @@ class AnthropicCompletion(BaseLLM):
|
||||
3. We send tool results back to Claude
|
||||
4. Claude processes results and generates final response
|
||||
"""
|
||||
# Execute all requested tools and collect results
|
||||
tool_results = []
|
||||
tool_results = self._execute_tools_and_collect_results(
|
||||
tool_uses, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
# Execute the tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Create tool result in Anthropic format
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
|
||||
# Prepare follow-up conversation with tool results
|
||||
follow_up_params = params.copy()
|
||||
|
||||
# Add Claude's tool use response to conversation
|
||||
assistant_message = {"role": "assistant", "content": initial_response.content}
|
||||
assistant_content: list[
|
||||
ThinkingBlock | ToolUseBlock | TextBlock | dict[str, Any]
|
||||
] = []
|
||||
for block in initial_response.content:
|
||||
thinking_block = self._extract_thinking_block(block)
|
||||
if thinking_block:
|
||||
assistant_content.append(thinking_block)
|
||||
elif block.type == "tool_use":
|
||||
assistant_content.append(
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": block.id,
|
||||
"name": block.name,
|
||||
"input": block.input,
|
||||
}
|
||||
)
|
||||
elif hasattr(block, "text"):
|
||||
assistant_content.append({"type": "text", "text": block.text})
|
||||
|
||||
assistant_message = {"role": "assistant", "content": assistant_content}
|
||||
|
||||
# Add user message with tool results
|
||||
user_message = {"role": "user", "content": tool_results}
|
||||
@@ -656,12 +784,20 @@ class AnthropicCompletion(BaseLLM):
|
||||
follow_up_usage = self._extract_anthropic_token_usage(final_response)
|
||||
self._track_token_usage_internal(follow_up_usage)
|
||||
|
||||
# Extract final text content
|
||||
final_content = ""
|
||||
thinking_blocks: list[ThinkingBlock] = []
|
||||
|
||||
if final_response.content:
|
||||
for content_block in final_response.content:
|
||||
if hasattr(content_block, "text"):
|
||||
final_content += content_block.text
|
||||
else:
|
||||
thinking_block = self._extract_thinking_block(content_block)
|
||||
if thinking_block:
|
||||
thinking_blocks.append(cast(ThinkingBlock, thinking_block))
|
||||
|
||||
if thinking_blocks:
|
||||
self.previous_thinking_blocks = thinking_blocks
|
||||
|
||||
final_content = self._apply_stop_words(final_content)
|
||||
|
||||
@@ -694,7 +830,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
# Fallback: return the first tool result if follow-up fails
|
||||
if tool_results:
|
||||
return tool_results[0]["content"]
|
||||
return cast(str, tool_results[0]["content"])
|
||||
raise e
|
||||
|
||||
async def _ahandle_completion(
|
||||
@@ -887,28 +1023,9 @@ class AnthropicCompletion(BaseLLM):
|
||||
3. We send tool results back to Claude
|
||||
4. Claude processes results and generates final response
|
||||
"""
|
||||
tool_results = []
|
||||
|
||||
for tool_use in tool_uses:
|
||||
function_name = tool_use.name
|
||||
function_args = tool_use.input
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
tool_result = {
|
||||
"type": "tool_result",
|
||||
"tool_use_id": tool_use.id,
|
||||
"content": str(result)
|
||||
if result is not None
|
||||
else "Tool execution completed",
|
||||
}
|
||||
tool_results.append(tool_result)
|
||||
tool_results = self._execute_tools_and_collect_results(
|
||||
tool_uses, available_functions, from_task, from_agent
|
||||
)
|
||||
|
||||
follow_up_params = params.copy()
|
||||
|
||||
@@ -963,7 +1080,7 @@ class AnthropicCompletion(BaseLLM):
|
||||
|
||||
logging.error(f"Tool follow-up conversation failed: {e}")
|
||||
if tool_results:
|
||||
return tool_results[0]["content"]
|
||||
return cast(str, tool_results[0]["content"])
|
||||
raise e
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
@@ -999,7 +1116,8 @@ class AnthropicCompletion(BaseLLM):
|
||||
# Default context window size for Claude models
|
||||
return int(200000 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_anthropic_token_usage(self, response: Message) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _extract_anthropic_token_usage(response: Message) -> dict[str, Any]:
|
||||
"""Extract token usage from Anthropic response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, TypedDict
|
||||
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Self
|
||||
@@ -18,7 +18,6 @@ from crewai.utilities.types import LLMMessage
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from crewai.llms.hooks.base import BaseInterceptor
|
||||
from crewai.tools.base_tool import BaseTool
|
||||
|
||||
|
||||
try:
|
||||
@@ -31,6 +30,8 @@ try:
|
||||
from azure.ai.inference.models import (
|
||||
ChatCompletions,
|
||||
ChatCompletionsToolCall,
|
||||
ChatCompletionsToolDefinition,
|
||||
FunctionDefinition,
|
||||
JsonSchemaFormat,
|
||||
StreamingChatCompletionsUpdate,
|
||||
)
|
||||
@@ -50,6 +51,24 @@ except ImportError:
|
||||
) from None
|
||||
|
||||
|
||||
class AzureCompletionParams(TypedDict, total=False):
|
||||
"""Type definition for Azure chat completion parameters."""
|
||||
|
||||
messages: list[LLMMessage]
|
||||
stream: bool
|
||||
model_extras: dict[str, Any]
|
||||
response_format: JsonSchemaFormat
|
||||
model: str
|
||||
temperature: float
|
||||
top_p: float
|
||||
frequency_penalty: float
|
||||
presence_penalty: float
|
||||
max_tokens: int
|
||||
stop: list[str]
|
||||
tools: list[ChatCompletionsToolDefinition]
|
||||
tool_choice: str
|
||||
|
||||
|
||||
class AzureCompletion(BaseLLM):
|
||||
"""Azure AI Inference native completion implementation.
|
||||
|
||||
@@ -156,7 +175,8 @@ class AzureCompletion(BaseLLM):
|
||||
and "/openai/deployments/" in self.endpoint
|
||||
)
|
||||
|
||||
def _validate_and_fix_endpoint(self, endpoint: str, model: str) -> str:
|
||||
@staticmethod
|
||||
def _validate_and_fix_endpoint(endpoint: str, model: str) -> str:
|
||||
"""Validate and fix Azure endpoint URL format.
|
||||
|
||||
Azure OpenAI endpoints should be in the format:
|
||||
@@ -179,10 +199,75 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
return endpoint
|
||||
|
||||
def _handle_api_error(
|
||||
self,
|
||||
error: Exception,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> None:
|
||||
"""Handle API errors with appropriate logging and events.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Raises:
|
||||
The original exception after logging and emitting events
|
||||
"""
|
||||
if isinstance(error, HttpResponseError):
|
||||
if error.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif error.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif error.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = (
|
||||
f"Azure API HTTP error: {error.status_code} - {error.message}"
|
||||
)
|
||||
else:
|
||||
error_msg = f"Azure API call failed: {error!s}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise error
|
||||
|
||||
def _handle_completion_error(
|
||||
self,
|
||||
error: Exception,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> None:
|
||||
"""Handle completion-specific errors including context length checks.
|
||||
|
||||
Args:
|
||||
error: The exception that occurred
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Raises:
|
||||
LLMContextLengthExceededError if context window exceeded, otherwise the original exception
|
||||
"""
|
||||
if is_context_length_exceeded(error):
|
||||
logging.error(f"Context window exceeded: {error}")
|
||||
raise LLMContextLengthExceededError(str(error)) from error
|
||||
|
||||
error_msg = f"Azure API call failed: {error!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise error
|
||||
|
||||
def call(
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -198,6 +283,7 @@ class AzureCompletion(BaseLLM):
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Response model
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
@@ -216,6 +302,9 @@ class AzureCompletion(BaseLLM):
|
||||
# Format messages for Azure
|
||||
formatted_messages = self._format_messages_for_azure(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare completion parameters
|
||||
completion_params = self._prepare_completion_params(
|
||||
formatted_messages, tools, response_model
|
||||
@@ -239,35 +328,13 @@ class AzureCompletion(BaseLLM):
|
||||
response_model,
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
return self._handle_api_error(e, from_task, from_agent) # type: ignore[func-returns-value]
|
||||
|
||||
async def acall(
|
||||
async def acall( # type: ignore[return]
|
||||
self,
|
||||
messages: str | list[LLMMessage],
|
||||
tools: list[dict[str, BaseTool]] | None = None,
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
callbacks: list[Any] | None = None,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -321,37 +388,15 @@ class AzureCompletion(BaseLLM):
|
||||
response_model,
|
||||
)
|
||||
|
||||
except HttpResponseError as e:
|
||||
if e.status_code == 401:
|
||||
error_msg = "Azure authentication failed. Check your API key."
|
||||
elif e.status_code == 404:
|
||||
error_msg = (
|
||||
f"Azure endpoint not found. Check endpoint URL: {self.endpoint}"
|
||||
)
|
||||
elif e.status_code == 429:
|
||||
error_msg = "Azure API rate limit exceeded. Please retry later."
|
||||
else:
|
||||
error_msg = f"Azure API HTTP error: {e.status_code} - {e.message}"
|
||||
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
except Exception as e:
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise
|
||||
self._handle_api_error(e, from_task, from_agent)
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: list[LLMMessage],
|
||||
tools: list[dict[str, Any]] | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> dict[str, Any]:
|
||||
) -> AzureCompletionParams:
|
||||
"""Prepare parameters for Azure AI Inference chat completion.
|
||||
|
||||
Args:
|
||||
@@ -362,11 +407,14 @@ class AzureCompletion(BaseLLM):
|
||||
Returns:
|
||||
Parameters dictionary for Azure API
|
||||
"""
|
||||
params = {
|
||||
params: AzureCompletionParams = {
|
||||
"messages": messages,
|
||||
"stream": self.stream,
|
||||
}
|
||||
|
||||
if self.stream:
|
||||
params["model_extras"] = {"stream_options": {"include_usage": True}}
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
model_description = generate_model_description(response_model)
|
||||
json_schema_info = model_description["json_schema"]
|
||||
@@ -409,37 +457,42 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
if drop_params and isinstance(additional_drop_params, list):
|
||||
for drop_param in additional_drop_params:
|
||||
params.pop(drop_param, None)
|
||||
if isinstance(drop_param, str):
|
||||
params.pop(drop_param, None) # type: ignore[misc]
|
||||
|
||||
return params
|
||||
|
||||
def _convert_tools_for_interference(
|
||||
def _convert_tools_for_interference( # type: ignore[override]
|
||||
self, tools: list[dict[str, Any]]
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Convert CrewAI tool format to Azure OpenAI function calling format."""
|
||||
) -> list[ChatCompletionsToolDefinition]:
|
||||
"""Convert CrewAI tool format to Azure OpenAI function calling format.
|
||||
|
||||
Args:
|
||||
tools: List of CrewAI tool definitions
|
||||
|
||||
Returns:
|
||||
List of Azure ChatCompletionsToolDefinition objects
|
||||
"""
|
||||
from crewai.llms.providers.utils.common import safe_tool_conversion
|
||||
|
||||
azure_tools = []
|
||||
azure_tools: list[ChatCompletionsToolDefinition] = []
|
||||
|
||||
for tool in tools:
|
||||
name, description, parameters = safe_tool_conversion(tool, "Azure")
|
||||
|
||||
azure_tool = {
|
||||
"type": "function",
|
||||
"function": {
|
||||
"name": name,
|
||||
"description": description,
|
||||
},
|
||||
}
|
||||
function_def = FunctionDefinition(
|
||||
name=name,
|
||||
description=description,
|
||||
parameters=parameters
|
||||
if isinstance(parameters, dict)
|
||||
else dict(parameters)
|
||||
if parameters
|
||||
else None,
|
||||
)
|
||||
|
||||
if parameters:
|
||||
if isinstance(parameters, dict):
|
||||
azure_tool["function"]["parameters"] = parameters # type: ignore
|
||||
else:
|
||||
azure_tool["function"]["parameters"] = dict(parameters)
|
||||
tool_def = ChatCompletionsToolDefinition(function=function_def)
|
||||
|
||||
azure_tools.append(azure_tool)
|
||||
azure_tools.append(tool_def)
|
||||
|
||||
return azure_tools
|
||||
|
||||
@@ -468,144 +521,239 @@ class AzureCompletion(BaseLLM):
|
||||
|
||||
return azure_messages
|
||||
|
||||
def _handle_completion(
|
||||
def _validate_and_emit_structured_output(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
content: str,
|
||||
response_model: type[BaseModel],
|
||||
params: AzureCompletionParams,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
# Make API call
|
||||
) -> str:
|
||||
"""Validate content against response model and emit completion event.
|
||||
|
||||
Args:
|
||||
content: Response content to validate
|
||||
response_model: Pydantic model for validation
|
||||
params: Completion parameters containing messages
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Validated and serialized JSON string
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
try:
|
||||
response: ChatCompletions = self.client.complete(**params)
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# Extract and track token usage
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
content = message.content or ""
|
||||
try:
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Extract content
|
||||
content = message.content or ""
|
||||
|
||||
# Apply stop words
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
# Emit completion event and return content
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
def _process_completion_response(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
response: ChatCompletions,
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Process completion response with usage tracking, tool execution, and events.
|
||||
|
||||
Args:
|
||||
response: Chat completion response from Azure API
|
||||
params: Completion parameters containing messages
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output
|
||||
|
||||
Returns:
|
||||
Response content or structured output
|
||||
"""
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
# Extract and track token usage
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
content = message.content or ""
|
||||
return self._validate_and_emit_structured_output(
|
||||
content=content,
|
||||
response_model=response_model,
|
||||
params=params,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle tool calls
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
# Execute tool
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
# Extract content
|
||||
content = message.content or ""
|
||||
|
||||
# Apply stop words
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
# Emit completion event and return content
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
|
||||
def _handle_completion(
|
||||
self,
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion."""
|
||||
try:
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = self.client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
except Exception as e:
|
||||
return self._handle_completion_error(e, from_task, from_agent) # type: ignore[func-returns-value]
|
||||
|
||||
def _process_streaming_update(
|
||||
self,
|
||||
update: StreamingChatCompletionsUpdate,
|
||||
full_response: str,
|
||||
tool_calls: dict[str, dict[str, str]],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
"""Process a single streaming update chunk.
|
||||
|
||||
# Make streaming API call
|
||||
for update in self.client.complete(**params):
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
Args:
|
||||
update: Streaming update from Azure API
|
||||
full_response: Accumulated response content
|
||||
tool_calls: Dictionary of accumulated tool calls
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
# Handle tool call streaming
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
Returns:
|
||||
Updated full_response string
|
||||
"""
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += tool_call.function.arguments
|
||||
|
||||
return full_response
|
||||
|
||||
def _finalize_streaming_response(
|
||||
self,
|
||||
full_response: str,
|
||||
tool_calls: dict[str, dict[str, str]],
|
||||
usage_data: dict[str, int],
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Finalize streaming response with usage tracking, tool execution, and events.
|
||||
|
||||
Args:
|
||||
full_response: The complete streamed response content
|
||||
tool_calls: Dictionary of tool calls accumulated during streaming
|
||||
usage_data: Token usage data from the stream
|
||||
params: Completion parameters containing messages
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output validation
|
||||
|
||||
Returns:
|
||||
Final response content after processing, or structured output
|
||||
"""
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
# Handle structured output validation
|
||||
if response_model and self.is_openai_model:
|
||||
return self._validate_and_emit_structured_output(
|
||||
content=full_response,
|
||||
response_model=response_model,
|
||||
params=params,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
# Handle completed tool calls
|
||||
if tool_calls and available_functions:
|
||||
@@ -642,11 +790,56 @@ class AzureCompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], full_response, from_agent
|
||||
)
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle streaming chat completion."""
|
||||
full_response = ""
|
||||
tool_calls: dict[str, dict[str, Any]] = {}
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
for update in self.client.complete(**params): # type: ignore[arg-type]
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.usage:
|
||||
usage = update.usage
|
||||
usage_data = {
|
||||
"prompt_tokens": usage.prompt_tokens,
|
||||
"completion_tokens": usage.completion_tokens,
|
||||
"total_tokens": usage.total_tokens,
|
||||
}
|
||||
continue
|
||||
|
||||
full_response = self._process_streaming_update(
|
||||
update=update,
|
||||
full_response=full_response,
|
||||
tool_calls=tool_calls,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
return self._finalize_streaming_response(
|
||||
full_response=full_response,
|
||||
tool_calls=tool_calls,
|
||||
usage_data=usage_data,
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
@@ -654,160 +847,64 @@ class AzureCompletion(BaseLLM):
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming chat completion asynchronously."""
|
||||
try:
|
||||
response: ChatCompletions = await self.async_client.complete(**params)
|
||||
|
||||
if not response.choices:
|
||||
raise ValueError("No choices returned from Azure API")
|
||||
|
||||
choice = response.choices[0]
|
||||
message = choice.message
|
||||
|
||||
usage = self._extract_azure_token_usage(response)
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response_model and self.is_openai_model:
|
||||
content = message.content or ""
|
||||
try:
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
if message.tool_calls and available_functions:
|
||||
tool_call = message.tool_calls[0] # Handle first tool call
|
||||
if isinstance(tool_call, ChatCompletionsToolCall):
|
||||
function_name = tool_call.function.name
|
||||
|
||||
try:
|
||||
function_args = json.loads(tool_call.function.arguments)
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse tool arguments: {e}")
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = message.content or ""
|
||||
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
# Cast params to Any to avoid type checking issues with TypedDict unpacking
|
||||
response: ChatCompletions = await self.async_client.complete(**params) # type: ignore[assignment,arg-type]
|
||||
return self._process_completion_response(
|
||||
response=response,
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
|
||||
error_msg = f"Azure API call failed: {e!s}"
|
||||
logging.error(error_msg)
|
||||
self._emit_call_failed_event(
|
||||
error=error_msg, from_task=from_task, from_agent=from_agent
|
||||
)
|
||||
raise e
|
||||
|
||||
return content
|
||||
return self._handle_completion_error(e, from_task, from_agent) # type: ignore[func-returns-value]
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
params: dict[str, Any],
|
||||
params: AzureCompletionParams,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
) -> str | Any:
|
||||
"""Handle streaming chat completion asynchronously."""
|
||||
full_response = ""
|
||||
tool_calls = {}
|
||||
tool_calls: dict[str, dict[str, Any]] = {}
|
||||
|
||||
stream = await self.async_client.complete(**params)
|
||||
async for update in stream:
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
stream = await self.async_client.complete(**params) # type: ignore[arg-type]
|
||||
async for update in stream: # type: ignore[union-attr]
|
||||
if isinstance(update, StreamingChatCompletionsUpdate):
|
||||
if update.choices:
|
||||
choice = update.choices[0]
|
||||
if choice.delta and choice.delta.content:
|
||||
content_delta = choice.delta.content
|
||||
full_response += content_delta
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=content_delta,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if choice.delta and choice.delta.tool_calls:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
call_id = tool_call.id or "default"
|
||||
if call_id not in tool_calls:
|
||||
tool_calls[call_id] = {
|
||||
"name": "",
|
||||
"arguments": "",
|
||||
}
|
||||
|
||||
if tool_call.function and tool_call.function.name:
|
||||
tool_calls[call_id]["name"] = tool_call.function.name
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += (
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
|
||||
try:
|
||||
function_args = json.loads(call_data["arguments"])
|
||||
except json.JSONDecodeError as e:
|
||||
logging.error(f"Failed to parse streamed tool arguments: {e}")
|
||||
if hasattr(update, "usage") and update.usage:
|
||||
usage = update.usage
|
||||
usage_data = {
|
||||
"prompt_tokens": getattr(usage, "prompt_tokens", 0),
|
||||
"completion_tokens": getattr(usage, "completion_tokens", 0),
|
||||
"total_tokens": getattr(usage, "total_tokens", 0),
|
||||
}
|
||||
continue
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
full_response = self._process_streaming_update(
|
||||
update=update,
|
||||
full_response=full_response,
|
||||
tool_calls=tool_calls,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
full_response = self._apply_stop_words(full_response)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
return self._finalize_streaming_response(
|
||||
full_response=full_response,
|
||||
tool_calls=tool_calls,
|
||||
usage_data=usage_data,
|
||||
params=params,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
# Azure OpenAI models support function calling
|
||||
@@ -851,7 +948,8 @@ class AzureCompletion(BaseLLM):
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_azure_token_usage(self, response: ChatCompletions) -> dict[str, Any]:
|
||||
@staticmethod
|
||||
def _extract_azure_token_usage(response: ChatCompletions) -> dict[str, Any]:
|
||||
"""Extract token usage from Azure response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
|
||||
@@ -312,9 +312,14 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
# Format messages for Converse API
|
||||
formatted_messages, system_message = self._format_messages_for_converse(
|
||||
messages # type: ignore[arg-type]
|
||||
messages
|
||||
)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(
|
||||
cast(list[LLMMessage], formatted_messages), from_agent
|
||||
):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
# Prepare request body
|
||||
body: BedrockConverseRequestBody = {
|
||||
"inferenceConfig": self._get_inference_config(),
|
||||
@@ -356,11 +361,19 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
if self.stream:
|
||||
return self._handle_streaming_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
cast(list[LLMMessage], formatted_messages),
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
return self._handle_converse(
|
||||
formatted_messages, body, available_functions, from_task, from_agent
|
||||
cast(list[LLMMessage], formatted_messages),
|
||||
body,
|
||||
available_functions,
|
||||
from_task,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
@@ -481,7 +494,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
def _handle_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
messages: list[LLMMessage],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: Mapping[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -605,7 +618,11 @@ class BedrockCompletion(BaseLLM):
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return text_content
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages,
|
||||
text_content,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
except ClientError as e:
|
||||
# Handle all AWS ClientError exceptions as per documentation
|
||||
@@ -662,7 +679,7 @@ class BedrockCompletion(BaseLLM):
|
||||
|
||||
def _handle_streaming_converse(
|
||||
self,
|
||||
messages: list[dict[str, Any]],
|
||||
messages: list[LLMMessage],
|
||||
body: BedrockConverseRequestBody,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -1149,16 +1166,25 @@ class BedrockCompletion(BaseLLM):
|
||||
messages=messages,
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages,
|
||||
full_response,
|
||||
from_agent,
|
||||
)
|
||||
|
||||
def _format_messages_for_converse(
|
||||
self, messages: str | list[dict[str, str]]
|
||||
self, messages: str | list[LLMMessage]
|
||||
) -> tuple[list[dict[str, Any]], str | None]:
|
||||
"""Format messages for Converse API following AWS documentation."""
|
||||
# Use base class formatting first
|
||||
formatted_messages = self._format_messages(messages) # type: ignore[arg-type]
|
||||
"""Format messages for Converse API following AWS documentation.
|
||||
|
||||
converse_messages = []
|
||||
Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
|
||||
a different content structure: {"role": str, "content": [{"text": str}]}
|
||||
rather than the standard {"role": str, "content": str}.
|
||||
"""
|
||||
# Use base class formatting first
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
converse_messages: list[dict[str, Any]] = []
|
||||
system_message: str | None = None
|
||||
|
||||
for message in formatted_messages:
|
||||
|
||||
@@ -3,7 +3,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal, cast
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
@@ -105,6 +105,7 @@ class GeminiCompletion(BaseLLM):
|
||||
self.stream = stream
|
||||
self.safety_settings = safety_settings or {}
|
||||
self.stop_sequences = stop_sequences or []
|
||||
self.tools: list[dict[str, Any]] | None = None
|
||||
|
||||
# Model-specific settings
|
||||
version_match = re.search(r"gemini-(\d+(?:\.\d+)?)", model.lower())
|
||||
@@ -223,10 +224,11 @@ class GeminiCompletion(BaseLLM):
|
||||
Args:
|
||||
messages: Input messages for the chat completion
|
||||
tools: List of tool/function definitions
|
||||
callbacks: Callback functions (not used as token counts are handled by the reponse)
|
||||
callbacks: Callback functions (not used as token counts are handled by the response)
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Response model to use.
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
@@ -246,6 +248,11 @@ class GeminiCompletion(BaseLLM):
|
||||
messages
|
||||
)
|
||||
|
||||
messages_for_hooks = self._convert_contents_to_dict(formatted_content)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
config = self._prepare_generation_config(
|
||||
system_instruction, tools, response_model
|
||||
)
|
||||
@@ -262,7 +269,6 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return self._handle_completion(
|
||||
formatted_content,
|
||||
system_instruction,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
@@ -304,6 +310,7 @@ class GeminiCompletion(BaseLLM):
|
||||
available_functions: Available functions for tool calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Response model to use.
|
||||
|
||||
Returns:
|
||||
Chat completion response or tool call result
|
||||
@@ -339,7 +346,6 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return await self._ahandle_completion(
|
||||
formatted_content,
|
||||
system_instruction,
|
||||
config,
|
||||
available_functions,
|
||||
from_task,
|
||||
@@ -492,35 +498,113 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
return contents, system_instruction
|
||||
|
||||
def _handle_completion(
|
||||
def _validate_and_emit_structured_output(
|
||||
self,
|
||||
content: str,
|
||||
response_model: type[BaseModel],
|
||||
messages_for_event: list[LLMMessage],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Validate content against response model and emit completion event.
|
||||
|
||||
Args:
|
||||
content: Response content to validate
|
||||
response_model: Pydantic model for validation
|
||||
messages_for_event: Messages to include in event
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Validated and serialized JSON string
|
||||
|
||||
Raises:
|
||||
ValueError: If validation fails
|
||||
"""
|
||||
try:
|
||||
structured_data = response_model.model_validate_json(content)
|
||||
structured_json = structured_data.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to validate structured output with model {response_model.__name__}: {e}"
|
||||
logging.error(error_msg)
|
||||
raise ValueError(error_msg) from e
|
||||
|
||||
def _finalize_completion_response(
|
||||
self,
|
||||
content: str,
|
||||
contents: list[types.Content],
|
||||
response_model: type[BaseModel] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> str:
|
||||
"""Finalize completion response with validation and event emission.
|
||||
|
||||
Args:
|
||||
content: The response content
|
||||
contents: Original contents for event conversion
|
||||
response_model: Pydantic model for structured output validation
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Final response content after processing
|
||||
"""
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
# Handle structured output validation
|
||||
if response_model:
|
||||
return self._validate_and_emit_structured_output(
|
||||
content=content,
|
||||
response_model=response_model,
|
||||
messages_for_event=messages_for_event,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
messages_for_event, content, from_agent
|
||||
)
|
||||
|
||||
def _process_response_with_tools(
|
||||
self,
|
||||
response: GenerateContentResponse,
|
||||
contents: list[types.Content],
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
"""Process response, execute function calls, and finalize completion.
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
Args:
|
||||
response: The completion response
|
||||
contents: Original contents for event conversion
|
||||
available_functions: Available functions for function calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output validation
|
||||
|
||||
Returns:
|
||||
Final response content or function call result
|
||||
"""
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
@@ -549,59 +633,90 @@ class GeminiCompletion(BaseLLM):
|
||||
content = response.text or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
return self._finalize_completion_response(
|
||||
content=content,
|
||||
contents=contents,
|
||||
response_model=response_model,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
def _handle_streaming_completion(
|
||||
def _process_stream_chunk(
|
||||
self,
|
||||
chunk: GenerateContentResponse,
|
||||
full_response: str,
|
||||
function_calls: dict[str, dict[str, Any]],
|
||||
usage_data: dict[str, int],
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
) -> tuple[str, dict[str, dict[str, Any]], dict[str, int]]:
|
||||
"""Process a single streaming chunk.
|
||||
|
||||
Args:
|
||||
chunk: The streaming chunk response
|
||||
full_response: Accumulated response text
|
||||
function_calls: Accumulated function calls
|
||||
usage_data: Accumulated usage data
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
|
||||
Returns:
|
||||
Tuple of (updated full_response, updated function_calls, updated usage_data)
|
||||
"""
|
||||
if chunk.usage_metadata:
|
||||
usage_data = self._extract_token_usage(chunk)
|
||||
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
|
||||
return full_response, function_calls, usage_data
|
||||
|
||||
def _finalize_streaming_response(
|
||||
self,
|
||||
full_response: str,
|
||||
function_calls: dict[str, dict[str, Any]],
|
||||
usage_data: dict[str, int],
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
"""Finalize streaming response with usage tracking, function execution, and events.
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
):
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
Args:
|
||||
full_response: The complete streamed response content
|
||||
function_calls: Dictionary of function calls accumulated during streaming
|
||||
usage_data: Token usage data from the stream
|
||||
contents: Original contents for event conversion
|
||||
available_functions: Available functions for function calling
|
||||
from_task: Task that initiated the call
|
||||
from_agent: Agent that initiated the call
|
||||
response_model: Pydantic model for structured output validation
|
||||
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
Returns:
|
||||
Final response content after processing
|
||||
"""
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
# Handle completed function calls
|
||||
if function_calls and available_functions:
|
||||
@@ -629,22 +744,95 @@ class GeminiCompletion(BaseLLM):
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
return self._finalize_completion_response(
|
||||
content=full_response,
|
||||
contents=contents,
|
||||
response_model=response_model,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
)
|
||||
|
||||
return full_response
|
||||
def _handle_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str | Any:
|
||||
"""Handle non-streaming content generation."""
|
||||
try:
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
response = self.client.models.generate_content(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
)
|
||||
|
||||
usage = self._extract_token_usage(response)
|
||||
except Exception as e:
|
||||
if is_context_length_exceeded(e):
|
||||
logging.error(f"Context window exceeded: {e}")
|
||||
raise LLMContextLengthExceededError(str(e)) from e
|
||||
raise e from e
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
return self._process_response_with_tools(
|
||||
response=response,
|
||||
contents=contents,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
def _handle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
from_agent: Any | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
) -> str:
|
||||
"""Handle streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
for chunk in self.client.models.generate_content_stream(
|
||||
model=self.model,
|
||||
contents=contents_for_api,
|
||||
config=config,
|
||||
):
|
||||
full_response, function_calls, usage_data = self._process_stream_chunk(
|
||||
chunk=chunk,
|
||||
full_response=full_response,
|
||||
function_calls=function_calls,
|
||||
usage_data=usage_data,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
return self._finalize_streaming_response(
|
||||
full_response=full_response,
|
||||
function_calls=function_calls,
|
||||
usage_data=usage_data,
|
||||
contents=contents,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
system_instruction: str | None,
|
||||
config: types.GenerateContentConfig,
|
||||
available_functions: dict[str, Any] | None = None,
|
||||
from_task: Any | None = None,
|
||||
@@ -670,46 +858,15 @@ class GeminiCompletion(BaseLLM):
|
||||
|
||||
self._track_token_usage_internal(usage)
|
||||
|
||||
if response.candidates and (self.tools or available_functions):
|
||||
candidate = response.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
function_name = part.function_call.name
|
||||
if function_name is None:
|
||||
continue
|
||||
function_args = (
|
||||
dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {}
|
||||
)
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions or {},
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
content = response.text or ""
|
||||
content = self._apply_stop_words(content)
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
return self._process_response_with_tools(
|
||||
response=response,
|
||||
contents=contents,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return content
|
||||
|
||||
async def _ahandle_streaming_completion(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
@@ -722,6 +879,7 @@ class GeminiCompletion(BaseLLM):
|
||||
"""Handle async streaming content generation."""
|
||||
full_response = ""
|
||||
function_calls: dict[str, dict[str, Any]] = {}
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
# The API accepts list[Content] but mypy is overly strict about variance
|
||||
contents_for_api: Any = contents
|
||||
@@ -731,64 +889,26 @@ class GeminiCompletion(BaseLLM):
|
||||
config=config,
|
||||
)
|
||||
async for chunk in stream:
|
||||
if chunk.text:
|
||||
full_response += chunk.text
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=chunk.text,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
full_response, function_calls, usage_data = self._process_stream_chunk(
|
||||
chunk=chunk,
|
||||
full_response=full_response,
|
||||
function_calls=function_calls,
|
||||
usage_data=usage_data,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if chunk.candidates:
|
||||
candidate = chunk.candidates[0]
|
||||
if candidate.content and candidate.content.parts:
|
||||
for part in candidate.content.parts:
|
||||
if hasattr(part, "function_call") and part.function_call:
|
||||
call_id = part.function_call.name or "default"
|
||||
if call_id not in function_calls:
|
||||
function_calls[call_id] = {
|
||||
"name": part.function_call.name,
|
||||
"args": dict(part.function_call.args)
|
||||
if part.function_call.args
|
||||
else {},
|
||||
}
|
||||
|
||||
if function_calls and available_functions:
|
||||
for call_data in function_calls.values():
|
||||
function_name = call_data["name"]
|
||||
function_args = call_data["args"]
|
||||
|
||||
# Skip if function_name is None
|
||||
if not isinstance(function_name, str):
|
||||
continue
|
||||
|
||||
# Ensure function_args is a dict
|
||||
if not isinstance(function_args, dict):
|
||||
function_args = {}
|
||||
|
||||
result = self._handle_tool_execution(
|
||||
function_name=function_name,
|
||||
function_args=function_args,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
if result is not None:
|
||||
return result
|
||||
|
||||
messages_for_event = self._convert_contents_to_dict(contents)
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=full_response,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
return self._finalize_streaming_response(
|
||||
full_response=full_response,
|
||||
function_calls=function_calls,
|
||||
usage_data=usage_data,
|
||||
contents=contents,
|
||||
available_functions=available_functions,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=messages_for_event,
|
||||
response_model=response_model,
|
||||
)
|
||||
|
||||
return full_response
|
||||
|
||||
def supports_function_calling(self) -> bool:
|
||||
"""Check if the model supports function calling."""
|
||||
return self.supports_tools
|
||||
@@ -848,12 +968,12 @@ class GeminiCompletion(BaseLLM):
|
||||
}
|
||||
return {"total_tokens": 0}
|
||||
|
||||
@staticmethod
|
||||
def _convert_contents_to_dict(
|
||||
self,
|
||||
contents: list[types.Content],
|
||||
) -> list[dict[str, str]]:
|
||||
) -> list[LLMMessage]:
|
||||
"""Convert contents to dict format."""
|
||||
result: list[dict[str, str]] = []
|
||||
result: list[LLMMessage] = []
|
||||
for content_obj in contents:
|
||||
role = content_obj.role
|
||||
if role == "model":
|
||||
@@ -866,5 +986,10 @@ class GeminiCompletion(BaseLLM):
|
||||
part.text for part in parts if hasattr(part, "text") and part.text
|
||||
)
|
||||
|
||||
result.append({"role": role, "content": content})
|
||||
result.append(
|
||||
LLMMessage(
|
||||
role=cast(Literal["user", "assistant", "system"], role),
|
||||
content=content,
|
||||
)
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
from __future__ import annotations
|
||||
|
||||
from collections.abc import AsyncIterator, Iterator
|
||||
from collections.abc import AsyncIterator
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
import httpx
|
||||
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI
|
||||
from openai import APIConnectionError, AsyncOpenAI, NotFoundError, OpenAI, Stream
|
||||
from openai.lib.streaming.chat import ChatCompletionStream
|
||||
from openai.types.chat import ChatCompletion, ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChoiceDelta
|
||||
@@ -189,6 +190,9 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
formatted_messages = self._format_messages(messages)
|
||||
|
||||
if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
|
||||
completion_params = self._prepare_completion_params(
|
||||
messages=formatted_messages, tools=tools
|
||||
)
|
||||
@@ -293,6 +297,7 @@ class OpenAICompletion(BaseLLM):
|
||||
}
|
||||
if self.stream:
|
||||
params["stream"] = self.stream
|
||||
params["stream_options"] = {"include_usage": True}
|
||||
|
||||
params.update(self.additional_params)
|
||||
|
||||
@@ -473,6 +478,10 @@ class OpenAICompletion(BaseLLM):
|
||||
|
||||
if usage.get("total_tokens", 0) > 0:
|
||||
logging.info(f"OpenAI API usage: {usage}")
|
||||
|
||||
content = self._invoke_after_llm_call_hooks(
|
||||
params["messages"], content, from_agent
|
||||
)
|
||||
except NotFoundError as e:
|
||||
error_msg = f"Model {self.model} not found: {e}"
|
||||
logging.error(error_msg)
|
||||
@@ -515,59 +524,61 @@ class OpenAICompletion(BaseLLM):
|
||||
tool_calls = {}
|
||||
|
||||
if response_model:
|
||||
completion_stream: Iterator[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
)
|
||||
parse_params = {
|
||||
k: v
|
||||
for k, v in params.items()
|
||||
if k not in ("response_format", "stream")
|
||||
}
|
||||
|
||||
accumulated_content = ""
|
||||
for chunk in completion_stream:
|
||||
if not chunk.choices:
|
||||
continue
|
||||
stream: ChatCompletionStream[BaseModel]
|
||||
with self.client.beta.chat.completions.stream(
|
||||
**parse_params, response_format=response_model
|
||||
) as stream:
|
||||
for chunk in stream:
|
||||
if chunk.type == "content.delta":
|
||||
delta_content = chunk.delta
|
||||
if delta_content:
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=delta_content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
choice = chunk.choices[0]
|
||||
delta: ChoiceDelta = choice.delta
|
||||
final_completion = stream.get_final_completion()
|
||||
if final_completion:
|
||||
usage = self._extract_openai_token_usage(final_completion)
|
||||
self._track_token_usage_internal(usage)
|
||||
if final_completion.choices:
|
||||
parsed_result = final_completion.choices[0].message.parsed
|
||||
if parsed_result:
|
||||
structured_json = parsed_result.model_dump_json()
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return structured_json
|
||||
|
||||
if delta.content:
|
||||
accumulated_content += delta.content
|
||||
self._emit_stream_chunk_event(
|
||||
chunk=delta.content,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
)
|
||||
logging.error("Failed to get parsed result from stream")
|
||||
return ""
|
||||
|
||||
try:
|
||||
parsed_object = response_model.model_validate_json(accumulated_content)
|
||||
structured_json = parsed_object.model_dump_json()
|
||||
|
||||
self._emit_call_completed_event(
|
||||
response=structured_json,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return structured_json
|
||||
except Exception as e:
|
||||
logging.error(f"Failed to parse structured output from stream: {e}")
|
||||
self._emit_call_completed_event(
|
||||
response=accumulated_content,
|
||||
call_type=LLMCallType.LLM_CALL,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent,
|
||||
messages=params["messages"],
|
||||
)
|
||||
return accumulated_content
|
||||
|
||||
stream: Iterator[ChatCompletionChunk] = self.client.chat.completions.create(
|
||||
**params
|
||||
completion_stream: Stream[ChatCompletionChunk] = (
|
||||
self.client.chat.completions.create(**params)
|
||||
)
|
||||
|
||||
for chunk in stream:
|
||||
if not chunk.choices:
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
for completion_chunk in completion_stream:
|
||||
if hasattr(completion_chunk, "usage") and completion_chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(completion_chunk)
|
||||
continue
|
||||
|
||||
choice = chunk.choices[0]
|
||||
if not completion_chunk.choices:
|
||||
continue
|
||||
|
||||
choice = completion_chunk.choices[0]
|
||||
chunk_delta: ChoiceDelta = choice.delta
|
||||
|
||||
if chunk_delta.content:
|
||||
@@ -592,6 +603,8 @@ class OpenAICompletion(BaseLLM):
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += tool_call.function.arguments
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
@@ -635,7 +648,9 @@ class OpenAICompletion(BaseLLM):
|
||||
messages=params["messages"],
|
||||
)
|
||||
|
||||
return full_response
|
||||
return self._invoke_after_llm_call_hooks(
|
||||
params["messages"], full_response, from_agent
|
||||
)
|
||||
|
||||
async def _ahandle_completion(
|
||||
self,
|
||||
@@ -782,7 +797,12 @@ class OpenAICompletion(BaseLLM):
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
accumulated_content = ""
|
||||
usage_data = {"total_tokens": 0}
|
||||
async for chunk in completion_stream:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
continue
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -797,6 +817,8 @@ class OpenAICompletion(BaseLLM):
|
||||
from_agent=from_agent,
|
||||
)
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
try:
|
||||
parsed_object = response_model.model_validate_json(accumulated_content)
|
||||
structured_json = parsed_object.model_dump_json()
|
||||
@@ -825,7 +847,13 @@ class OpenAICompletion(BaseLLM):
|
||||
ChatCompletionChunk
|
||||
] = await self.async_client.chat.completions.create(**params)
|
||||
|
||||
usage_data = {"total_tokens": 0}
|
||||
|
||||
async for chunk in stream:
|
||||
if hasattr(chunk, "usage") and chunk.usage:
|
||||
usage_data = self._extract_openai_token_usage(chunk)
|
||||
continue
|
||||
|
||||
if not chunk.choices:
|
||||
continue
|
||||
|
||||
@@ -854,6 +882,8 @@ class OpenAICompletion(BaseLLM):
|
||||
if tool_call.function and tool_call.function.arguments:
|
||||
tool_calls[call_id]["arguments"] += tool_call.function.arguments
|
||||
|
||||
self._track_token_usage_internal(usage_data)
|
||||
|
||||
if tool_calls and available_functions:
|
||||
for call_data in tool_calls.values():
|
||||
function_name = call_data["name"]
|
||||
@@ -941,8 +971,10 @@ class OpenAICompletion(BaseLLM):
|
||||
# Default context window size
|
||||
return int(8192 * CONTEXT_WINDOW_USAGE_RATIO)
|
||||
|
||||
def _extract_openai_token_usage(self, response: ChatCompletion) -> dict[str, Any]:
|
||||
"""Extract token usage from OpenAI ChatCompletion response."""
|
||||
def _extract_openai_token_usage(
|
||||
self, response: ChatCompletion | ChatCompletionChunk
|
||||
) -> dict[str, Any]:
|
||||
"""Extract token usage from OpenAI ChatCompletion or ChatCompletionChunk response."""
|
||||
if hasattr(response, "usage") and response.usage:
|
||||
usage = response.usage
|
||||
return {
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from crewai.memory import (
|
||||
@@ -16,6 +17,8 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class ContextualMemory:
|
||||
"""Aggregates and retrieves context from multiple memory sources."""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
stm: ShortTermMemory,
|
||||
@@ -46,9 +49,14 @@ class ContextualMemory:
|
||||
self.exm.task = self.task
|
||||
|
||||
def build_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""
|
||||
Automatically builds a minimal, highly relevant set of contextual information
|
||||
for a given task.
|
||||
"""Build contextual information for a task synchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
@@ -63,6 +71,31 @@ class ContextualMemory:
|
||||
]
|
||||
return "\n".join(filter(None, context_parts))
|
||||
|
||||
async def abuild_context_for_task(self, task: Task, context: str) -> str:
|
||||
"""Build contextual information for a task asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task to build context for.
|
||||
context: Additional context string.
|
||||
|
||||
Returns:
|
||||
Formatted context string from all memory sources.
|
||||
"""
|
||||
query = f"{task.description} {context}".strip()
|
||||
|
||||
if query == "":
|
||||
return ""
|
||||
|
||||
# Fetch all contexts concurrently
|
||||
results = await asyncio.gather(
|
||||
self._afetch_ltm_context(task.description),
|
||||
self._afetch_stm_context(query),
|
||||
self._afetch_entity_context(query),
|
||||
self._afetch_external_context(query),
|
||||
)
|
||||
|
||||
return "\n".join(filter(None, results))
|
||||
|
||||
def _fetch_stm_context(self, query: str) -> str:
|
||||
"""
|
||||
Fetches recent relevant insights from STM related to the task's description and expected_output,
|
||||
@@ -135,3 +168,87 @@ class ContextualMemory:
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
async def _afetch_stm_context(self, query: str) -> str:
|
||||
"""Fetch recent relevant insights from STM asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted insights as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.stm is None:
|
||||
return ""
|
||||
|
||||
stm_results = await self.stm.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in stm_results]
|
||||
)
|
||||
return f"Recent Insights:\n{formatted_results}" if stm_results else ""
|
||||
|
||||
async def _afetch_ltm_context(self, task: str) -> str | None:
|
||||
"""Fetch historical data from LTM asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
|
||||
Returns:
|
||||
Formatted historical data as bullet points, or None if none found.
|
||||
"""
|
||||
if self.ltm is None:
|
||||
return ""
|
||||
|
||||
ltm_results = await self.ltm.asearch(task, latest_n=2)
|
||||
if not ltm_results:
|
||||
return None
|
||||
|
||||
formatted_results = [
|
||||
suggestion
|
||||
for result in ltm_results
|
||||
for suggestion in result["metadata"]["suggestions"]
|
||||
]
|
||||
formatted_results = list(dict.fromkeys(formatted_results))
|
||||
formatted_results = "\n".join([f"- {result}" for result in formatted_results]) # type: ignore # Incompatible types in assignment (expression has type "str", variable has type "list[str]")
|
||||
|
||||
return f"Historical Data:\n{formatted_results}" if ltm_results else ""
|
||||
|
||||
async def _afetch_entity_context(self, query: str) -> str:
|
||||
"""Fetch relevant entity information asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted entity information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.em is None:
|
||||
return ""
|
||||
|
||||
em_results = await self.em.asearch(query)
|
||||
formatted_results = "\n".join(
|
||||
[f"- {result['content']}" for result in em_results]
|
||||
)
|
||||
return f"Entities:\n{formatted_results}" if em_results else ""
|
||||
|
||||
async def _afetch_external_context(self, query: str) -> str:
|
||||
"""Fetch relevant information from External Memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
|
||||
Returns:
|
||||
Formatted information as bullet points, or empty string if none found.
|
||||
"""
|
||||
if self.exm is None:
|
||||
return ""
|
||||
|
||||
external_memories = await self.exm.asearch(query)
|
||||
|
||||
if not external_memories:
|
||||
return ""
|
||||
|
||||
formatted_memories = "\n".join(
|
||||
f"- {result['content']}" for result in external_memories
|
||||
)
|
||||
return f"External memories:\n{formatted_memories}"
|
||||
|
||||
@@ -26,7 +26,13 @@ class EntityMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -43,7 +49,7 @@ class EntityMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -170,7 +176,17 @@ class EntityMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search entity memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -217,6 +233,168 @@ class EntityMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: EntityMemoryItem | list[EntityMemoryItem],
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save entity items asynchronously.
|
||||
|
||||
Args:
|
||||
value: Single EntityMemoryItem or list of EntityMemoryItems to save.
|
||||
metadata: Optional metadata dict (not used, for signature compatibility).
|
||||
"""
|
||||
if not value:
|
||||
return
|
||||
|
||||
items = value if isinstance(value, list) else [value]
|
||||
is_batch = len(items) > 1
|
||||
|
||||
metadata = {"entity_count": len(items)} if is_batch else items[0].metadata
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
metadata=metadata,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
saved_count = 0
|
||||
errors: list[str | None] = []
|
||||
|
||||
async def save_single_item(item: EntityMemoryItem) -> tuple[bool, str | None]:
|
||||
"""Save a single item asynchronously."""
|
||||
try:
|
||||
if self._memory_provider == "mem0":
|
||||
data = f"""
|
||||
Remember details about the following entity:
|
||||
Name: {item.name}
|
||||
Type: {item.type}
|
||||
Entity Description: {item.description}
|
||||
"""
|
||||
else:
|
||||
data = f"{item.name}({item.type}): {item.description}"
|
||||
|
||||
await super(EntityMemory, self).asave(data, item.metadata)
|
||||
return True, None
|
||||
except Exception as e:
|
||||
return False, f"{item.name}: {e!s}"
|
||||
|
||||
try:
|
||||
for item in items:
|
||||
success, error = await save_single_item(item)
|
||||
if success:
|
||||
saved_count += 1
|
||||
else:
|
||||
errors.append(error)
|
||||
|
||||
if is_batch:
|
||||
emit_value = f"Saved {saved_count} entities"
|
||||
metadata = {"entity_count": saved_count, "errors": errors}
|
||||
else:
|
||||
emit_value = f"{items[0].name}({items[0].type}): {items[0].description}"
|
||||
metadata = items[0].metadata
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=emit_value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
if errors:
|
||||
raise Exception(
|
||||
f"Partial save: {len(errors)} failed out of {len(items)}"
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
fail_metadata = (
|
||||
{"entity_count": len(items), "saved": saved_count}
|
||||
if is_batch
|
||||
else items[0].metadata
|
||||
)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
metadata=fail_metadata,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search entity memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="entity_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="entity_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
self.storage.reset()
|
||||
|
||||
@@ -30,7 +30,7 @@ class ExternalMemory(Memory):
|
||||
def _configure_mem0(crew: Any, config: dict[str, Any]) -> Mem0Storage:
|
||||
from crewai.memory.storage.mem0_storage import Mem0Storage
|
||||
|
||||
return Mem0Storage(type="external", crew=crew, config=config)
|
||||
return Mem0Storage(type="external", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
|
||||
@staticmethod
|
||||
def external_supported_storages() -> dict[str, Any]:
|
||||
@@ -53,7 +53,10 @@ class ExternalMemory(Memory):
|
||||
if provider not in supported_storages:
|
||||
raise ValueError(f"Provider {provider} not supported")
|
||||
|
||||
return supported_storages[provider](crew, embedder_config.get("config", {}))
|
||||
storage: Storage = supported_storages[provider](
|
||||
crew, embedder_config.get("config", {})
|
||||
)
|
||||
return storage
|
||||
|
||||
def save(
|
||||
self,
|
||||
@@ -111,7 +114,17 @@ class ExternalMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search external memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -158,6 +171,124 @@ class ExternalMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to external memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ExternalMemoryItem(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
await super().asave(value=item.value, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search external memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await super().asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="external_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="external_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
self.storage.reset()
|
||||
|
||||
|
||||
@@ -24,7 +24,11 @@ class LongTermMemory(Memory):
|
||||
LongTermMemoryItem instances.
|
||||
"""
|
||||
|
||||
def __init__(self, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
storage: LTMSQLiteStorage | None = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
if not storage:
|
||||
storage = LTMSQLiteStorage(db_path=path) if path else LTMSQLiteStorage()
|
||||
super().__init__(storage=storage)
|
||||
@@ -48,7 +52,7 @@ class LongTermMemory(Memory):
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
self.storage.save( # type: ignore # BUG?: Unexpected keyword argument "task_description","score","datetime" for "save" of "Storage"
|
||||
self.storage.save(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
@@ -80,11 +84,20 @@ class LongTermMemory(Memory):
|
||||
)
|
||||
raise
|
||||
|
||||
def search( # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
def search( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]: # type: ignore # signature of "search" incompatible with supertype "Memory"
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -98,7 +111,7 @@ class LongTermMemory(Memory):
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = self.storage.load(task, latest_n) # type: ignore # BUG?: "Storage" has no attribute "load"
|
||||
results = self.storage.load(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -113,7 +126,118 @@ class LongTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(self, item: LongTermMemoryItem) -> None: # type: ignore[override]
|
||||
"""Save an item to long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
item: The LongTermMemoryItem to save.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
metadata = item.metadata
|
||||
metadata.update(
|
||||
{"agent": item.agent, "expected_output": item.expected_output}
|
||||
)
|
||||
await self.storage.asave(
|
||||
task_description=item.task,
|
||||
score=metadata["quality"],
|
||||
metadata=metadata,
|
||||
datetime=item.datetime,
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=item.task,
|
||||
metadata=item.metadata,
|
||||
agent_role=item.agent,
|
||||
error=str(e),
|
||||
source_type="long_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch( # type: ignore[override]
|
||||
self,
|
||||
task: str,
|
||||
latest_n: int = 3,
|
||||
) -> list[dict[str, Any]]:
|
||||
"""Search long-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
task: The task description to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=task,
|
||||
limit=latest_n,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.aload(task, latest_n)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=task,
|
||||
results=results,
|
||||
limit=latest_n,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="long_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return results or []
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -127,4 +251,5 @@ class LongTermMemory(Memory):
|
||||
raise
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset long-term memory."""
|
||||
self.storage.reset()
|
||||
|
||||
@@ -13,9 +13,7 @@ if TYPE_CHECKING:
|
||||
|
||||
|
||||
class Memory(BaseModel):
|
||||
"""
|
||||
Base class for memory, now supporting agent tags and generic metadata.
|
||||
"""
|
||||
"""Base class for memory, supporting agent tags and generic metadata."""
|
||||
|
||||
embedder_config: EmbedderConfig | dict[str, Any] | None = None
|
||||
crew: Any | None = None
|
||||
@@ -52,20 +50,72 @@ class Memory(BaseModel):
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
metadata = metadata or {}
|
||||
"""Save a value to memory.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
self.storage.save(value, metadata)
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
metadata = metadata or {}
|
||||
await self.storage.asave(value, metadata)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
return self.storage.search(
|
||||
"""Search memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search memory for relevant entries asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
results: list[Any] = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
return results
|
||||
|
||||
def set_crew(self, crew: Any) -> Memory:
|
||||
"""Set the crew for this memory instance."""
|
||||
self.crew = crew
|
||||
return self
|
||||
|
||||
@@ -30,7 +30,13 @@ class ShortTermMemory(Memory):
|
||||
|
||||
_memory_provider: str | None = PrivateAttr()
|
||||
|
||||
def __init__(self, crew=None, embedder_config=None, storage=None, path=None):
|
||||
def __init__(
|
||||
self,
|
||||
crew: Any = None,
|
||||
embedder_config: Any = None,
|
||||
storage: Any = None,
|
||||
path: str | None = None,
|
||||
) -> None:
|
||||
memory_provider = None
|
||||
if embedder_config and isinstance(embedder_config, dict):
|
||||
memory_provider = embedder_config.get("provider")
|
||||
@@ -47,7 +53,7 @@ class ShortTermMemory(Memory):
|
||||
if embedder_config and isinstance(embedder_config, dict)
|
||||
else None
|
||||
)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config)
|
||||
storage = Mem0Storage(type="short_term", crew=crew, config=config) # type: ignore[no-untyped-call]
|
||||
else:
|
||||
storage = (
|
||||
storage
|
||||
@@ -123,7 +129,17 @@ class ShortTermMemory(Memory):
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
):
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory for relevant entries.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
@@ -140,7 +156,7 @@ class ShortTermMemory(Memory):
|
||||
try:
|
||||
results = self.storage.search(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
) # type: ignore # BUG? The reference is to the parent class, but the parent class does not have this parameters
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
@@ -156,7 +172,130 @@ class ShortTermMemory(Memory):
|
||||
),
|
||||
)
|
||||
|
||||
return results
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryFailedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
value: Any,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Save a value to short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Optional metadata to associate with the value.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveStartedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
item = ShortTermMemoryItem(
|
||||
data=value,
|
||||
metadata=metadata,
|
||||
agent=self.agent.role if self.agent else None,
|
||||
)
|
||||
if self._memory_provider == "mem0":
|
||||
item.data = (
|
||||
f"Remember the following insights from Agent run: {item.data}"
|
||||
)
|
||||
|
||||
await super().asave(value=item.data, metadata=item.metadata)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveCompletedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
save_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemorySaveFailedEvent(
|
||||
value=value,
|
||||
metadata=metadata,
|
||||
error=str(e),
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
raise
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search short-term memory asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries.
|
||||
"""
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryStartedEvent(
|
||||
query=query,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
start_time = time.time()
|
||||
try:
|
||||
results = await self.storage.asearch(
|
||||
query=query, limit=limit, score_threshold=score_threshold
|
||||
)
|
||||
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
event=MemoryQueryCompletedEvent(
|
||||
query=query,
|
||||
results=results,
|
||||
limit=limit,
|
||||
score_threshold=score_threshold,
|
||||
query_time_ms=(time.time() - start_time) * 1000,
|
||||
source_type="short_term_memory",
|
||||
from_agent=self.agent,
|
||||
from_task=self.task,
|
||||
),
|
||||
)
|
||||
|
||||
return list(results)
|
||||
except Exception as e:
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
|
||||
@@ -3,29 +3,30 @@ from pathlib import Path
|
||||
import sqlite3
|
||||
from typing import Any
|
||||
|
||||
import aiosqlite
|
||||
|
||||
from crewai.utilities import Printer
|
||||
from crewai.utilities.paths import db_storage_path
|
||||
|
||||
|
||||
class LTMSQLiteStorage:
|
||||
"""
|
||||
An updated SQLite storage class for LTM data storage.
|
||||
"""
|
||||
"""SQLite storage class for long-term memory data."""
|
||||
|
||||
def __init__(self, db_path: str | None = None) -> None:
|
||||
"""Initialize the SQLite storage.
|
||||
|
||||
Args:
|
||||
db_path: Optional path to the database file.
|
||||
"""
|
||||
if db_path is None:
|
||||
# Get the parent directory of the default db path and create our db file there
|
||||
db_path = str(Path(db_storage_path()) / "long_term_memory_storage.db")
|
||||
self.db_path = db_path
|
||||
self._printer: Printer = Printer()
|
||||
# Ensure parent directory exists
|
||||
Path(self.db_path).parent.mkdir(parents=True, exist_ok=True)
|
||||
self._initialize_db()
|
||||
|
||||
def _initialize_db(self):
|
||||
"""
|
||||
Initializes the SQLite database and creates LTM table
|
||||
"""
|
||||
def _initialize_db(self) -> None:
|
||||
"""Initialize the SQLite database and create LTM table."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
cursor = conn.cursor()
|
||||
@@ -106,9 +107,7 @@ class LTMSQLiteStorage:
|
||||
)
|
||||
return None
|
||||
|
||||
def reset(
|
||||
self,
|
||||
) -> None:
|
||||
def reset(self) -> None:
|
||||
"""Resets the LTM table with error handling."""
|
||||
try:
|
||||
with sqlite3.connect(self.db_path) as conn:
|
||||
@@ -121,4 +120,87 @@ class LTMSQLiteStorage:
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return
|
||||
|
||||
async def asave(
|
||||
self,
|
||||
task_description: str,
|
||||
metadata: dict[str, Any],
|
||||
datetime: str,
|
||||
score: int | float,
|
||||
) -> None:
|
||||
"""Save data to the LTM table asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task.
|
||||
metadata: Metadata associated with the memory.
|
||||
datetime: Timestamp of the memory.
|
||||
score: Quality score of the memory.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute(
|
||||
"""
|
||||
INSERT INTO long_term_memories (task_description, metadata, datetime, score)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""",
|
||||
(task_description, json.dumps(metadata), datetime, score),
|
||||
)
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while saving to LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
async def aload(
|
||||
self, task_description: str, latest_n: int
|
||||
) -> list[dict[str, Any]] | None:
|
||||
"""Query the LTM table by task description asynchronously.
|
||||
|
||||
Args:
|
||||
task_description: Description of the task to search for.
|
||||
latest_n: Maximum number of results to return.
|
||||
|
||||
Returns:
|
||||
List of matching memory entries or None if error occurs.
|
||||
"""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
cursor = await conn.execute(
|
||||
f"""
|
||||
SELECT metadata, datetime, score
|
||||
FROM long_term_memories
|
||||
WHERE task_description = ?
|
||||
ORDER BY datetime DESC, score ASC
|
||||
LIMIT {latest_n}
|
||||
""", # nosec # noqa: S608
|
||||
(task_description,),
|
||||
)
|
||||
rows = await cursor.fetchall()
|
||||
if rows:
|
||||
return [
|
||||
{
|
||||
"metadata": json.loads(row[0]),
|
||||
"datetime": row[1],
|
||||
"score": row[2],
|
||||
}
|
||||
for row in rows
|
||||
]
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while querying LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
return None
|
||||
|
||||
async def areset(self) -> None:
|
||||
"""Reset the LTM table asynchronously."""
|
||||
try:
|
||||
async with aiosqlite.connect(self.db_path) as conn:
|
||||
await conn.execute("DELETE FROM long_term_memories")
|
||||
await conn.commit()
|
||||
except aiosqlite.Error as e:
|
||||
self._printer.print(
|
||||
content=f"MEMORY ERROR: An error occurred while deleting all rows in LTM: {e}",
|
||||
color="red",
|
||||
)
|
||||
|
||||
@@ -129,6 +129,12 @@ class RAGStorage(BaseRAGStorage):
|
||||
return f"{base_path}/{file_name}"
|
||||
|
||||
def save(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -167,6 +173,51 @@ class RAGStorage(BaseRAGStorage):
|
||||
f"Error during {self.type} save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
async def asave(self, value: Any, metadata: dict[str, Any]) -> None:
|
||||
"""Save a value to storage asynchronously.
|
||||
|
||||
Args:
|
||||
value: The value to save.
|
||||
metadata: Metadata to associate with the value.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
await client.aget_or_create_collection(collection_name=collection_name)
|
||||
|
||||
document: BaseRecord = {"content": value}
|
||||
if metadata:
|
||||
document["metadata"] = metadata
|
||||
|
||||
batch_size = None
|
||||
if (
|
||||
self.embedder_config
|
||||
and isinstance(self.embedder_config, dict)
|
||||
and "config" in self.embedder_config
|
||||
):
|
||||
nested_config = self.embedder_config["config"]
|
||||
if isinstance(nested_config, dict):
|
||||
batch_size = nested_config.get("batch_size")
|
||||
|
||||
if batch_size is not None:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name,
|
||||
documents=[document],
|
||||
batch_size=cast(int, batch_size),
|
||||
)
|
||||
else:
|
||||
await client.aadd_documents(
|
||||
collection_name=collection_name, documents=[document]
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async save: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
|
||||
def search(
|
||||
self,
|
||||
query: str,
|
||||
@@ -174,6 +225,17 @@ class RAGStorage(BaseRAGStorage):
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
@@ -194,6 +256,44 @@ class RAGStorage(BaseRAGStorage):
|
||||
)
|
||||
return []
|
||||
|
||||
async def asearch(
|
||||
self,
|
||||
query: str,
|
||||
limit: int = 5,
|
||||
filter: dict[str, Any] | None = None,
|
||||
score_threshold: float = 0.6,
|
||||
) -> list[Any]:
|
||||
"""Search for matching entries in storage asynchronously.
|
||||
|
||||
Args:
|
||||
query: The search query.
|
||||
limit: Maximum number of results to return.
|
||||
filter: Optional metadata filter.
|
||||
score_threshold: Minimum similarity score for results.
|
||||
|
||||
Returns:
|
||||
List of matching entries.
|
||||
"""
|
||||
try:
|
||||
client = self._get_client()
|
||||
collection_name = (
|
||||
f"memory_{self.type}_{self.agents}"
|
||||
if self.agents
|
||||
else f"memory_{self.type}"
|
||||
)
|
||||
return await client.asearch(
|
||||
collection_name=collection_name,
|
||||
query=query,
|
||||
limit=limit,
|
||||
metadata_filter=filter,
|
||||
score_threshold=score_threshold,
|
||||
)
|
||||
except Exception as e:
|
||||
logging.error(
|
||||
f"Error during {self.type} async search: {e!s}\n{traceback.format_exc()}"
|
||||
)
|
||||
return []
|
||||
|
||||
def reset(self) -> None:
|
||||
try:
|
||||
client = self._get_client()
|
||||
|
||||
@@ -1,21 +1,35 @@
|
||||
"""HuggingFace embeddings provider."""
|
||||
|
||||
from chromadb.utils.embedding_functions.huggingface_embedding_function import (
|
||||
HuggingFaceEmbeddingServer,
|
||||
HuggingFaceEmbeddingFunction,
|
||||
)
|
||||
from pydantic import AliasChoices, Field
|
||||
|
||||
from crewai.rag.core.base_embeddings_provider import BaseEmbeddingsProvider
|
||||
|
||||
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingServer]):
|
||||
"""HuggingFace embeddings provider."""
|
||||
class HuggingFaceProvider(BaseEmbeddingsProvider[HuggingFaceEmbeddingFunction]):
|
||||
"""HuggingFace embeddings provider for the HuggingFace Inference API."""
|
||||
|
||||
embedding_callable: type[HuggingFaceEmbeddingServer] = Field(
|
||||
default=HuggingFaceEmbeddingServer,
|
||||
embedding_callable: type[HuggingFaceEmbeddingFunction] = Field(
|
||||
default=HuggingFaceEmbeddingFunction,
|
||||
description="HuggingFace embedding function class",
|
||||
)
|
||||
url: str = Field(
|
||||
description="HuggingFace API URL",
|
||||
validation_alias=AliasChoices("EMBEDDINGS_HUGGINGFACE_URL", "HUGGINGFACE_URL"),
|
||||
api_key: str | None = Field(
|
||||
default=None,
|
||||
description="HuggingFace API key",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_API_KEY",
|
||||
"HUGGINGFACE_API_KEY",
|
||||
"HF_TOKEN",
|
||||
),
|
||||
)
|
||||
model_name: str = Field(
|
||||
default="sentence-transformers/all-MiniLM-L6-v2",
|
||||
description="Model name to use for embeddings",
|
||||
validation_alias=AliasChoices(
|
||||
"EMBEDDINGS_HUGGINGFACE_MODEL_NAME",
|
||||
"HUGGINGFACE_MODEL_NAME",
|
||||
"model",
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1,6 +1,6 @@
|
||||
"""Type definitions for HuggingFace embedding providers."""
|
||||
|
||||
from typing import Literal
|
||||
from typing import Annotated, Literal
|
||||
|
||||
from typing_extensions import Required, TypedDict
|
||||
|
||||
@@ -8,7 +8,11 @@ from typing_extensions import Required, TypedDict
|
||||
class HuggingFaceProviderConfig(TypedDict, total=False):
|
||||
"""Configuration for HuggingFace provider."""
|
||||
|
||||
url: str
|
||||
api_key: str
|
||||
model: Annotated[
|
||||
str, "sentence-transformers/all-MiniLM-L6-v2"
|
||||
] # alias for model_name for backward compat
|
||||
model_name: Annotated[str, "sentence-transformers/all-MiniLM-L6-v2"]
|
||||
|
||||
|
||||
class HuggingFaceProviderSpec(TypedDict, total=False):
|
||||
|
||||
@@ -497,6 +497,107 @@ class Task(BaseModel):
|
||||
result = self._execute_core(agent, context, tools)
|
||||
future.set_result(result)
|
||||
|
||||
async def aexecute_sync(
|
||||
self,
|
||||
agent: BaseAgent | None = None,
|
||||
context: str | None = None,
|
||||
tools: list[BaseTool] | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Execute the task asynchronously using native async/await."""
|
||||
return await self._aexecute_core(agent, context, tools)
|
||||
|
||||
async def _aexecute_core(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
context: str | None,
|
||||
tools: list[Any] | None,
|
||||
) -> TaskOutput:
|
||||
"""Run the core execution logic of the task asynchronously."""
|
||||
try:
|
||||
agent = agent or self.agent
|
||||
self.agent = agent
|
||||
if not agent:
|
||||
raise Exception(
|
||||
f"The task '{self.description}' has no agent assigned, therefore it can't be executed directly and should be executed in a Crew using a specific process that support that, like hierarchical."
|
||||
)
|
||||
|
||||
self.start_time = datetime.datetime.now()
|
||||
|
||||
self.prompt_context = context
|
||||
tools = tools or self.tools or []
|
||||
|
||||
self.processed_by_agents.add(agent.role)
|
||||
crewai_event_bus.emit(self, TaskStartedEvent(context=context, task=self)) # type: ignore[no-untyped-call]
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
if not self._guardrails and not self._guardrail:
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
else:
|
||||
pydantic_output, json_output = None, None
|
||||
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if self._guardrails:
|
||||
for idx, guardrail in enumerate(self._guardrails):
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=guardrail,
|
||||
guardrail_index=idx,
|
||||
)
|
||||
|
||||
if self._guardrail:
|
||||
task_output = await self._ainvoke_guardrail_function(
|
||||
task_output=task_output,
|
||||
agent=agent,
|
||||
tools=tools,
|
||||
guardrail=self._guardrail,
|
||||
)
|
||||
|
||||
self.output = task_output
|
||||
self.end_time = datetime.datetime.now()
|
||||
|
||||
if self.callback:
|
||||
self.callback(self.output)
|
||||
|
||||
crew = self.agent.crew # type: ignore[union-attr]
|
||||
if crew and crew.task_callback and crew.task_callback != self.callback:
|
||||
crew.task_callback(self.output)
|
||||
|
||||
if self.output_file:
|
||||
content = (
|
||||
json_output
|
||||
if json_output
|
||||
else (
|
||||
pydantic_output.model_dump_json() if pydantic_output else result
|
||||
)
|
||||
)
|
||||
self._save_file(content)
|
||||
crewai_event_bus.emit(
|
||||
self,
|
||||
TaskCompletedEvent(output=task_output, task=self), # type: ignore[no-untyped-call]
|
||||
)
|
||||
return task_output
|
||||
except Exception as e:
|
||||
self.end_time = datetime.datetime.now()
|
||||
crewai_event_bus.emit(self, TaskFailedEvent(error=str(e), task=self)) # type: ignore[no-untyped-call]
|
||||
raise e # Re-raise the exception after emitting the event
|
||||
|
||||
def _execute_core(
|
||||
self,
|
||||
agent: BaseAgent | None,
|
||||
@@ -539,7 +640,7 @@ class Task(BaseModel):
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages,
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
if self._guardrails:
|
||||
@@ -950,7 +1051,103 @@ Follow these guidelines:
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages,
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
return task_output
|
||||
|
||||
async def _ainvoke_guardrail_function(
|
||||
self,
|
||||
task_output: TaskOutput,
|
||||
agent: BaseAgent,
|
||||
tools: list[BaseTool],
|
||||
guardrail: GuardrailCallable | None,
|
||||
guardrail_index: int | None = None,
|
||||
) -> TaskOutput:
|
||||
"""Invoke the guardrail function asynchronously."""
|
||||
if not guardrail:
|
||||
return task_output
|
||||
|
||||
if guardrail_index is not None:
|
||||
current_retry_count = self._guardrail_retry_counts.get(guardrail_index, 0)
|
||||
else:
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
max_attempts = self.guardrail_max_retries + 1
|
||||
|
||||
for attempt in range(max_attempts):
|
||||
guardrail_result = process_guardrail(
|
||||
output=task_output,
|
||||
guardrail=guardrail,
|
||||
retry_count=current_retry_count,
|
||||
event_source=self,
|
||||
from_task=self,
|
||||
from_agent=agent,
|
||||
)
|
||||
|
||||
if guardrail_result.success:
|
||||
if guardrail_result.result is None:
|
||||
raise Exception(
|
||||
"Task guardrail returned None as result. This is not allowed."
|
||||
)
|
||||
|
||||
if isinstance(guardrail_result.result, str):
|
||||
task_output.raw = guardrail_result.result
|
||||
pydantic_output, json_output = self._export_output(
|
||||
guardrail_result.result
|
||||
)
|
||||
task_output.pydantic = pydantic_output
|
||||
task_output.json_dict = json_output
|
||||
elif isinstance(guardrail_result.result, TaskOutput):
|
||||
task_output = guardrail_result.result
|
||||
|
||||
return task_output
|
||||
|
||||
if attempt >= self.guardrail_max_retries:
|
||||
guardrail_name = (
|
||||
f"guardrail {guardrail_index}"
|
||||
if guardrail_index is not None
|
||||
else "guardrail"
|
||||
)
|
||||
raise Exception(
|
||||
f"Task failed {guardrail_name} validation after {self.guardrail_max_retries} retries. "
|
||||
f"Last error: {guardrail_result.error}"
|
||||
)
|
||||
|
||||
if guardrail_index is not None:
|
||||
current_retry_count += 1
|
||||
self._guardrail_retry_counts[guardrail_index] = current_retry_count
|
||||
else:
|
||||
self.retry_count += 1
|
||||
current_retry_count = self.retry_count
|
||||
|
||||
context = self.i18n.errors("validation_error").format(
|
||||
guardrail_result_error=guardrail_result.error,
|
||||
task_output=task_output.raw,
|
||||
)
|
||||
printer = Printer()
|
||||
printer.print(
|
||||
content=f"Guardrail {guardrail_index if guardrail_index is not None else ''} blocked (attempt {attempt + 1}/{max_attempts}), retrying due to: {guardrail_result.error}\n",
|
||||
color="yellow",
|
||||
)
|
||||
|
||||
result = await agent.aexecute_task(
|
||||
task=self,
|
||||
context=context,
|
||||
tools=tools,
|
||||
)
|
||||
|
||||
pydantic_output, json_output = self._export_output(result)
|
||||
task_output = TaskOutput(
|
||||
name=self.name or self.description,
|
||||
description=self.description,
|
||||
expected_output=self.expected_output,
|
||||
raw=result,
|
||||
pydantic=pydantic_output,
|
||||
json_dict=json_output,
|
||||
agent=agent.role,
|
||||
output_format=self._get_output_format(),
|
||||
messages=agent.last_messages, # type: ignore[attr-defined]
|
||||
)
|
||||
|
||||
return task_output
|
||||
|
||||
@@ -392,9 +392,7 @@ class Telemetry:
|
||||
self._add_attribute(span, "platform_system", platform.system())
|
||||
self._add_attribute(span, "platform_version", platform.version())
|
||||
self._add_attribute(span, "cpus", os.cpu_count())
|
||||
self._add_attribute(
|
||||
span, "crew_inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
|
||||
else:
|
||||
self._add_attribute(
|
||||
span,
|
||||
@@ -707,9 +705,7 @@ class Telemetry:
|
||||
self._add_attribute(span, "model_name", model_name)
|
||||
|
||||
if crew.share_crew:
|
||||
self._add_attribute(
|
||||
span, "inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "inputs", json.dumps(inputs or {}))
|
||||
|
||||
close_span(span)
|
||||
|
||||
@@ -814,9 +810,7 @@ class Telemetry:
|
||||
add_crew_attributes(
|
||||
span, crew, self._add_attribute, include_fingerprint=False
|
||||
)
|
||||
self._add_attribute(
|
||||
span, "crew_inputs", json.dumps(inputs) if inputs else None
|
||||
)
|
||||
self._add_attribute(span, "crew_inputs", json.dumps(inputs or {}))
|
||||
self._add_attribute(
|
||||
span,
|
||||
"crew_agents",
|
||||
|
||||
@@ -2,9 +2,18 @@ from __future__ import annotations
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
import asyncio
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Awaitable, Callable
|
||||
from inspect import signature
|
||||
from typing import Any, cast, get_args, get_origin
|
||||
from typing import (
|
||||
Any,
|
||||
Generic,
|
||||
ParamSpec,
|
||||
TypeVar,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
)
|
||||
|
||||
from pydantic import (
|
||||
BaseModel,
|
||||
@@ -14,6 +23,7 @@ from pydantic import (
|
||||
create_model,
|
||||
field_validator,
|
||||
)
|
||||
from typing_extensions import TypeIs
|
||||
|
||||
from crewai.tools.structured_tool import CrewStructuredTool
|
||||
from crewai.utilities.printer import Printer
|
||||
@@ -21,6 +31,19 @@ from crewai.utilities.printer import Printer
|
||||
|
||||
_printer = Printer()
|
||||
|
||||
P = ParamSpec("P")
|
||||
R = TypeVar("R", covariant=True)
|
||||
|
||||
|
||||
def _is_async_callable(func: Callable[..., Any]) -> bool:
|
||||
"""Check if a callable is async."""
|
||||
return asyncio.iscoroutinefunction(func)
|
||||
|
||||
|
||||
def _is_awaitable(value: R | Awaitable[R]) -> TypeIs[Awaitable[R]]:
|
||||
"""Type narrowing check for awaitable values."""
|
||||
return asyncio.iscoroutine(value) or asyncio.isfuture(value)
|
||||
|
||||
|
||||
class EnvVar(BaseModel):
|
||||
name: str
|
||||
@@ -55,7 +78,7 @@ class BaseTool(BaseModel, ABC):
|
||||
default=False, description="Flag to check if the description has been updated."
|
||||
)
|
||||
|
||||
cache_function: Callable = Field(
|
||||
cache_function: Callable[..., bool] = Field(
|
||||
default=lambda _args=None, _result=None: True,
|
||||
description="Function that will be used to determine if the tool should be cached, should return a boolean. If None, the tool will be cached.",
|
||||
)
|
||||
@@ -123,6 +146,35 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
return result
|
||||
|
||||
async def arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Execute the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments to pass to the tool.
|
||||
**kwargs: Keyword arguments to pass to the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
async def _arun(
|
||||
self,
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Async implementation of the tool. Override for async support."""
|
||||
raise NotImplementedError(
|
||||
f"{self.__class__.__name__} does not implement _arun. "
|
||||
"Override _arun for async support or use run() for sync execution."
|
||||
)
|
||||
|
||||
def reset_usage_count(self) -> None:
|
||||
"""Reset the current usage count to zero."""
|
||||
self.current_usage_count = 0
|
||||
@@ -133,7 +185,17 @@ class BaseTool(BaseModel, ABC):
|
||||
*args: Any,
|
||||
**kwargs: Any,
|
||||
) -> Any:
|
||||
"""Here goes the actual implementation of the tool."""
|
||||
"""Sync implementation of the tool.
|
||||
|
||||
Subclasses must implement this method for synchronous execution.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
|
||||
def to_structured_tool(self) -> CrewStructuredTool:
|
||||
"""Convert this tool to a CrewStructuredTool instance."""
|
||||
@@ -239,21 +301,90 @@ class BaseTool(BaseModel, ABC):
|
||||
|
||||
if args:
|
||||
args_str = ", ".join(BaseTool._get_arg_annotations(arg) for arg in args)
|
||||
return f"{origin.__name__}[{args_str}]"
|
||||
return str(f"{origin.__name__}[{args_str}]")
|
||||
|
||||
return origin.__name__
|
||||
return str(origin.__name__)
|
||||
|
||||
|
||||
class Tool(BaseTool):
|
||||
"""The function that will be executed when the tool is called."""
|
||||
class Tool(BaseTool, Generic[P, R]):
|
||||
"""Tool that wraps a callable function.
|
||||
|
||||
func: Callable
|
||||
|
||||
def _run(self, *args: Any, **kwargs: Any) -> Any:
|
||||
return self.func(*args, **kwargs)
|
||||
Type Parameters:
|
||||
P: ParamSpec capturing the function's parameters.
|
||||
R: The return type of the function.
|
||||
"""
|
||||
|
||||
func: Callable[P, R | Awaitable[R]]
|
||||
|
||||
def run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the tool synchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
_printer.print(f"Using Tool: {self.name}", color="cyan")
|
||||
result = self.func(*args, **kwargs)
|
||||
|
||||
if asyncio.iscoroutine(result):
|
||||
result = asyncio.run(result)
|
||||
|
||||
self.current_usage_count += 1
|
||||
return result # type: ignore[return-value]
|
||||
|
||||
def _run(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the function.
|
||||
**kwargs: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
The result of the function execution.
|
||||
"""
|
||||
return self.func(*args, **kwargs) # type: ignore[return-value]
|
||||
|
||||
async def arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the tool asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the tool.
|
||||
**kwargs: Keyword arguments for the tool.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution.
|
||||
"""
|
||||
result = await self._arun(*args, **kwargs)
|
||||
self.current_usage_count += 1
|
||||
return result
|
||||
|
||||
async def _arun(self, *args: P.args, **kwargs: P.kwargs) -> R:
|
||||
"""Executes the wrapped function asynchronously.
|
||||
|
||||
Args:
|
||||
*args: Positional arguments for the function.
|
||||
**kwargs: Keyword arguments for the function.
|
||||
|
||||
Returns:
|
||||
The result of the async function execution.
|
||||
|
||||
Raises:
|
||||
NotImplementedError: If the wrapped function is not async.
|
||||
"""
|
||||
result = self.func(*args, **kwargs)
|
||||
if _is_awaitable(result):
|
||||
return await result
|
||||
raise NotImplementedError(
|
||||
f"{self.name} does not have an async function. "
|
||||
"Use run() for sync execution or provide an async function."
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_langchain(cls, tool: Any) -> Tool:
|
||||
def from_langchain(cls, tool: Any) -> Tool[..., Any]:
|
||||
"""Create a Tool instance from a CrewStructuredTool.
|
||||
|
||||
This method takes a CrewStructuredTool object and converts it into a
|
||||
@@ -261,10 +392,10 @@ class Tool(BaseTool):
|
||||
attribute and infers the argument schema if not explicitly provided.
|
||||
|
||||
Args:
|
||||
tool (Any): The CrewStructuredTool object to be converted.
|
||||
tool: The CrewStructuredTool object to be converted.
|
||||
|
||||
Returns:
|
||||
Tool: A new Tool instance created from the provided CrewStructuredTool.
|
||||
A new Tool instance created from the provided CrewStructuredTool.
|
||||
|
||||
Raises:
|
||||
ValueError: If the provided tool does not have a callable 'func' attribute.
|
||||
@@ -308,37 +439,83 @@ class Tool(BaseTool):
|
||||
def to_langchain(
|
||||
tools: list[BaseTool | CrewStructuredTool],
|
||||
) -> list[CrewStructuredTool]:
|
||||
"""Convert a list of tools to CrewStructuredTool instances."""
|
||||
return [t.to_structured_tool() if isinstance(t, BaseTool) else t for t in tools]
|
||||
|
||||
|
||||
P2 = ParamSpec("P2")
|
||||
R2 = TypeVar("R2")
|
||||
|
||||
|
||||
@overload
|
||||
def tool(func: Callable[P2, R2], /) -> Tool[P2, R2]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
*args, result_as_answer: bool = False, max_usage_count: int | None = None
|
||||
) -> Callable:
|
||||
"""
|
||||
Decorator to create a tool from a function.
|
||||
name: str,
|
||||
/,
|
||||
*,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def tool(
|
||||
*,
|
||||
result_as_answer: bool = ...,
|
||||
max_usage_count: int | None = ...,
|
||||
) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]: ...
|
||||
|
||||
|
||||
def tool(
|
||||
*args: Callable[P2, R2] | str,
|
||||
result_as_answer: bool = False,
|
||||
max_usage_count: int | None = None,
|
||||
) -> Tool[P2, R2] | Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
"""Decorator to create a Tool from a function.
|
||||
|
||||
Can be used in three ways:
|
||||
1. @tool - decorator without arguments, uses function name
|
||||
2. @tool("name") - decorator with custom name
|
||||
3. @tool(result_as_answer=True) - decorator with options
|
||||
|
||||
Args:
|
||||
*args: Positional arguments, either the function to decorate or the tool name.
|
||||
result_as_answer: Flag to indicate if the tool result should be used as the final agent answer.
|
||||
max_usage_count: Maximum number of times this tool can be used. None means unlimited usage.
|
||||
*args: Either the function to decorate or a custom tool name.
|
||||
result_as_answer: If True, the tool result becomes the final agent answer.
|
||||
max_usage_count: Maximum times this tool can be used. None means unlimited.
|
||||
|
||||
Returns:
|
||||
A Tool instance.
|
||||
|
||||
Example:
|
||||
@tool
|
||||
def greet(name: str) -> str:
|
||||
'''Greet someone.'''
|
||||
return f"Hello, {name}!"
|
||||
|
||||
result = greet.run("World")
|
||||
"""
|
||||
|
||||
def _make_with_name(tool_name: str) -> Callable:
|
||||
def _make_tool(f: Callable) -> BaseTool:
|
||||
def _make_with_name(tool_name: str) -> Callable[[Callable[P2, R2]], Tool[P2, R2]]:
|
||||
def _make_tool(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
||||
if f.__doc__ is None:
|
||||
raise ValueError("Function must have a docstring")
|
||||
if f.__annotations__ is None:
|
||||
|
||||
func_annotations = getattr(f, "__annotations__", None)
|
||||
if func_annotations is None:
|
||||
raise ValueError("Function must have type annotations")
|
||||
|
||||
class_name = "".join(tool_name.split()).title()
|
||||
args_schema = cast(
|
||||
tool_args_schema = cast(
|
||||
type[PydanticBaseModel],
|
||||
type(
|
||||
class_name,
|
||||
(PydanticBaseModel,),
|
||||
{
|
||||
"__annotations__": {
|
||||
k: v for k, v in f.__annotations__.items() if k != "return"
|
||||
k: v for k, v in func_annotations.items() if k != "return"
|
||||
},
|
||||
},
|
||||
),
|
||||
@@ -348,10 +525,9 @@ def tool(
|
||||
name=tool_name,
|
||||
description=f.__doc__,
|
||||
func=f,
|
||||
args_schema=args_schema,
|
||||
args_schema=tool_args_schema,
|
||||
result_as_answer=result_as_answer,
|
||||
max_usage_count=max_usage_count,
|
||||
current_usage_count=0,
|
||||
)
|
||||
|
||||
return _make_tool
|
||||
@@ -360,4 +536,10 @@ def tool(
|
||||
return _make_with_name(args[0].__name__)(args[0])
|
||||
if len(args) == 1 and isinstance(args[0], str):
|
||||
return _make_with_name(args[0])
|
||||
if len(args) == 0:
|
||||
|
||||
def decorator(f: Callable[P2, R2]) -> Tool[P2, R2]:
|
||||
return _make_with_name(f.__name__)(f)
|
||||
|
||||
return decorator
|
||||
raise ValueError("Invalid arguments")
|
||||
|
||||
@@ -160,6 +160,251 @@ class ToolUsage:
|
||||
|
||||
return f"{self._use(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||
|
||||
async def ause(
|
||||
self, calling: ToolCalling | InstructorToolCalling, tool_string: str
|
||||
) -> str:
|
||||
"""Execute a tool asynchronously.
|
||||
|
||||
Args:
|
||||
calling: The tool calling information.
|
||||
tool_string: The raw tool string from the agent.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution as a string.
|
||||
"""
|
||||
if isinstance(calling, ToolUsageError):
|
||||
error = calling.message
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return error
|
||||
|
||||
try:
|
||||
tool = self._select_tool(calling.tool_name)
|
||||
except Exception as e:
|
||||
error = getattr(e, "message", str(e))
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
return error
|
||||
|
||||
if (
|
||||
isinstance(tool, CrewStructuredTool)
|
||||
and tool.name == self._i18n.tools("add_image")["name"] # type: ignore
|
||||
):
|
||||
try:
|
||||
return await self._ause(
|
||||
tool_string=tool_string, tool=tool, calling=calling
|
||||
)
|
||||
except Exception as e:
|
||||
error = getattr(e, "message", str(e))
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(content=f"\n\n{error}\n", color="red")
|
||||
return error
|
||||
|
||||
return (
|
||||
f"{await self._ause(tool_string=tool_string, tool=tool, calling=calling)}"
|
||||
)
|
||||
|
||||
async def _ause(
|
||||
self,
|
||||
tool_string: str,
|
||||
tool: CrewStructuredTool,
|
||||
calling: ToolCalling | InstructorToolCalling,
|
||||
) -> str:
|
||||
"""Internal async tool execution implementation.
|
||||
|
||||
Args:
|
||||
tool_string: The raw tool string from the agent.
|
||||
tool: The tool to execute.
|
||||
calling: The tool calling information.
|
||||
|
||||
Returns:
|
||||
The result of the tool execution as a string.
|
||||
"""
|
||||
if self._check_tool_repeated_usage(calling=calling):
|
||||
try:
|
||||
result = self._i18n.errors("task_repeated_usage").format(
|
||||
tool_names=self.tools_names
|
||||
)
|
||||
self._telemetry.tool_repeated_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
return self._format_result(result=result)
|
||||
except Exception:
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
if self.agent:
|
||||
event_data = {
|
||||
"agent_key": self.agent.key,
|
||||
"agent_role": self.agent.role,
|
||||
"tool_name": self.action.tool,
|
||||
"tool_args": self.action.tool_input,
|
||||
"tool_class": self.action.tool,
|
||||
"agent": self.agent,
|
||||
}
|
||||
|
||||
if self.agent.fingerprint: # type: ignore
|
||||
event_data.update(self.agent.fingerprint) # type: ignore
|
||||
if self.task:
|
||||
event_data["task_name"] = self.task.name or self.task.description
|
||||
event_data["task_id"] = str(self.task.id)
|
||||
crewai_event_bus.emit(self, ToolUsageStartedEvent(**event_data))
|
||||
|
||||
started_at = time.time()
|
||||
from_cache = False
|
||||
result = None # type: ignore
|
||||
|
||||
if self.tools_handler and self.tools_handler.cache:
|
||||
input_str = ""
|
||||
if calling.arguments:
|
||||
if isinstance(calling.arguments, dict):
|
||||
input_str = json.dumps(calling.arguments)
|
||||
else:
|
||||
input_str = str(calling.arguments)
|
||||
|
||||
result = self.tools_handler.cache.read(
|
||||
tool=calling.tool_name, input=input_str
|
||||
) # type: ignore
|
||||
from_cache = result is not None
|
||||
|
||||
available_tool = next(
|
||||
(
|
||||
available_tool
|
||||
for available_tool in self.tools
|
||||
if available_tool.name == tool.name
|
||||
),
|
||||
None,
|
||||
)
|
||||
|
||||
usage_limit_error = self._check_usage_limit(available_tool, tool.name)
|
||||
if usage_limit_error:
|
||||
try:
|
||||
result = usage_limit_error
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
return self._format_result(result=result)
|
||||
except Exception:
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
|
||||
if result is None:
|
||||
try:
|
||||
if calling.tool_name in [
|
||||
"Delegate work to coworker",
|
||||
"Ask question to coworker",
|
||||
]:
|
||||
coworker = (
|
||||
calling.arguments.get("coworker") if calling.arguments else None
|
||||
)
|
||||
if self.task:
|
||||
self.task.increment_delegations(coworker)
|
||||
|
||||
if calling.arguments:
|
||||
try:
|
||||
acceptable_args = tool.args_schema.model_json_schema()[
|
||||
"properties"
|
||||
].keys()
|
||||
arguments = {
|
||||
k: v
|
||||
for k, v in calling.arguments.items()
|
||||
if k in acceptable_args
|
||||
}
|
||||
arguments = self._add_fingerprint_metadata(arguments)
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
except Exception:
|
||||
arguments = calling.arguments
|
||||
arguments = self._add_fingerprint_metadata(arguments)
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
else:
|
||||
arguments = self._add_fingerprint_metadata({})
|
||||
result = await tool.ainvoke(input=arguments)
|
||||
except Exception as e:
|
||||
self.on_tool_error(tool=tool, tool_calling=calling, e=e)
|
||||
self._run_attempts += 1
|
||||
if self._run_attempts > self._max_parsing_attempts:
|
||||
self._telemetry.tool_usage_error(llm=self.function_calling_llm)
|
||||
error_message = self._i18n.errors("tool_usage_exception").format(
|
||||
error=e, tool=tool.name, tool_inputs=tool.description
|
||||
)
|
||||
error = ToolUsageError(
|
||||
f"\n{error_message}.\nMoving on then. {self._i18n.slice('format').format(tool_names=self.tools_names)}"
|
||||
).message
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
if self.agent and self.agent.verbose:
|
||||
self._printer.print(
|
||||
content=f"\n\n{error_message}\n", color="red"
|
||||
)
|
||||
return error
|
||||
|
||||
if self.task:
|
||||
self.task.increment_tools_errors()
|
||||
return await self.ause(calling=calling, tool_string=tool_string)
|
||||
|
||||
if self.tools_handler:
|
||||
should_cache = True
|
||||
if (
|
||||
hasattr(available_tool, "cache_function")
|
||||
and available_tool.cache_function
|
||||
):
|
||||
should_cache = available_tool.cache_function(
|
||||
calling.arguments, result
|
||||
)
|
||||
|
||||
self.tools_handler.on_tool_use(
|
||||
calling=calling, output=result, should_cache=should_cache
|
||||
)
|
||||
|
||||
self._telemetry.tool_usage(
|
||||
llm=self.function_calling_llm,
|
||||
tool_name=tool.name,
|
||||
attempts=self._run_attempts,
|
||||
)
|
||||
result = self._format_result(result=result)
|
||||
data = {
|
||||
"result": result,
|
||||
"tool_name": tool.name,
|
||||
"tool_args": calling.arguments,
|
||||
}
|
||||
|
||||
self.on_tool_use_finished(
|
||||
tool=tool,
|
||||
tool_calling=calling,
|
||||
from_cache=from_cache,
|
||||
started_at=started_at,
|
||||
result=result,
|
||||
)
|
||||
|
||||
if (
|
||||
hasattr(available_tool, "result_as_answer")
|
||||
and available_tool.result_as_answer # type: ignore
|
||||
):
|
||||
result_as_answer = available_tool.result_as_answer # type: ignore
|
||||
data["result_as_answer"] = result_as_answer # type: ignore
|
||||
|
||||
if self.agent and hasattr(self.agent, "tools_results"):
|
||||
self.agent.tools_results.append(data)
|
||||
|
||||
if available_tool and hasattr(available_tool, "current_usage_count"):
|
||||
available_tool.current_usage_count += 1
|
||||
if (
|
||||
hasattr(available_tool, "max_usage_count")
|
||||
and available_tool.max_usage_count is not None
|
||||
):
|
||||
self._printer.print(
|
||||
content=f"Tool '{available_tool.name}' usage: {available_tool.current_usage_count}/{available_tool.max_usage_count}",
|
||||
color="blue",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
def _use(
|
||||
self,
|
||||
tool_string: str,
|
||||
|
||||
@@ -237,22 +237,22 @@ def get_llm_response(
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM and return the response, handling any invalid responses.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to call
|
||||
messages: The messages to send to the LLM
|
||||
callbacks: List of callbacks for the LLM call
|
||||
printer: Printer instance for output
|
||||
from_task: Optional task context for the LLM call
|
||||
from_agent: Optional agent context for the LLM call
|
||||
response_model: Optional Pydantic model for structured outputs
|
||||
executor_context: Optional executor context for hook invocation
|
||||
llm: The LLM instance to call.
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
@@ -284,6 +284,60 @@ def get_llm_response(
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
async def aget_llm_response(
|
||||
llm: LLM | BaseLLM,
|
||||
messages: list[LLMMessage],
|
||||
callbacks: list[TokenCalcHandler],
|
||||
printer: Printer,
|
||||
from_task: Task | None = None,
|
||||
from_agent: Agent | LiteAgent | None = None,
|
||||
response_model: type[BaseModel] | None = None,
|
||||
executor_context: CrewAgentExecutor | None = None,
|
||||
) -> str:
|
||||
"""Call the LLM asynchronously and return the response.
|
||||
|
||||
Args:
|
||||
llm: The LLM instance to call.
|
||||
messages: The messages to send to the LLM.
|
||||
callbacks: List of callbacks for the LLM call.
|
||||
printer: Printer instance for output.
|
||||
from_task: Optional task context for the LLM call.
|
||||
from_agent: Optional agent context for the LLM call.
|
||||
response_model: Optional Pydantic model for structured outputs.
|
||||
executor_context: Optional executor context for hook invocation.
|
||||
|
||||
Returns:
|
||||
The response from the LLM as a string.
|
||||
|
||||
Raises:
|
||||
Exception: If an error occurs.
|
||||
ValueError: If the response is None or empty.
|
||||
"""
|
||||
if executor_context is not None:
|
||||
if not _setup_before_llm_call_hooks(executor_context, printer):
|
||||
raise ValueError("LLM call blocked by before_llm_call hook")
|
||||
messages = executor_context.messages
|
||||
|
||||
try:
|
||||
answer = await llm.acall(
|
||||
messages,
|
||||
callbacks=callbacks,
|
||||
from_task=from_task,
|
||||
from_agent=from_agent, # type: ignore[arg-type]
|
||||
response_model=response_model,
|
||||
)
|
||||
except Exception as e:
|
||||
raise e
|
||||
if not answer:
|
||||
printer.print(
|
||||
content="Received None or empty response from LLM call.",
|
||||
color="red",
|
||||
)
|
||||
raise ValueError("Invalid response from LLM call - None or empty.")
|
||||
|
||||
return _setup_after_llm_call_hooks(executor_context, answer, printer)
|
||||
|
||||
|
||||
def process_llm_response(
|
||||
answer: str, use_stop_words: bool
|
||||
) -> AgentAction | AgentFinish:
|
||||
@@ -673,7 +727,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:
|
||||
|
||||
|
||||
def _setup_before_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None, printer: Printer
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
|
||||
) -> bool:
|
||||
"""Setup and invoke before_llm_call hooks for the executor context.
|
||||
|
||||
@@ -723,7 +777,7 @@ def _setup_before_llm_call_hooks(
|
||||
|
||||
|
||||
def _setup_after_llm_call_hooks(
|
||||
executor_context: CrewAgentExecutor | None,
|
||||
executor_context: CrewAgentExecutor | LiteAgent | None,
|
||||
answer: str,
|
||||
printer: Printer,
|
||||
) -> str:
|
||||
|
||||
@@ -26,6 +26,138 @@ if TYPE_CHECKING:
|
||||
from crewai.task import Task
|
||||
|
||||
|
||||
async def aexecute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
i18n: I18N,
|
||||
agent_key: str | None = None,
|
||||
agent_role: str | None = None,
|
||||
tools_handler: ToolsHandler | None = None,
|
||||
task: Task | None = None,
|
||||
agent: Agent | BaseAgent | None = None,
|
||||
function_calling_llm: BaseLLM | LLM | None = None,
|
||||
fingerprint_context: dict[str, str] | None = None,
|
||||
crew: Crew | None = None,
|
||||
) -> ToolResult:
|
||||
"""Execute a tool asynchronously and check if the result should be a final answer.
|
||||
|
||||
This is the async version of execute_tool_and_check_finality. It integrates tool
|
||||
hooks for before and after tool execution, allowing programmatic interception
|
||||
and modification of tool calls.
|
||||
|
||||
Args:
|
||||
agent_action: The action containing the tool to execute.
|
||||
tools: List of available tools.
|
||||
i18n: Internationalization settings.
|
||||
agent_key: Optional key for event emission.
|
||||
agent_role: Optional role for event emission.
|
||||
tools_handler: Optional tools handler for tool execution.
|
||||
task: Optional task for tool execution.
|
||||
agent: Optional agent instance for tool execution.
|
||||
function_calling_llm: Optional LLM for function calling.
|
||||
fingerprint_context: Optional context for fingerprinting.
|
||||
crew: Optional crew instance for hook context.
|
||||
|
||||
Returns:
|
||||
ToolResult containing the execution result and whether it should be
|
||||
treated as a final answer.
|
||||
"""
|
||||
logger = Logger(verbose=crew.verbose if crew else False)
|
||||
tool_name_to_tool_map = {tool.name: tool for tool in tools}
|
||||
|
||||
if agent_key and agent_role and agent:
|
||||
fingerprint_context = fingerprint_context or {}
|
||||
if agent:
|
||||
if hasattr(agent, "set_fingerprint") and callable(agent.set_fingerprint):
|
||||
if isinstance(fingerprint_context, dict):
|
||||
try:
|
||||
fingerprint_obj = Fingerprint.from_dict(fingerprint_context)
|
||||
agent.set_fingerprint(fingerprint=fingerprint_obj)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to set fingerprint: {e}") from e
|
||||
|
||||
tool_usage = ToolUsage(
|
||||
tools_handler=tools_handler,
|
||||
tools=tools,
|
||||
function_calling_llm=function_calling_llm, # type: ignore[arg-type]
|
||||
task=task,
|
||||
agent=agent,
|
||||
action=agent_action,
|
||||
)
|
||||
|
||||
tool_calling = tool_usage.parse_tool_calling(agent_action.text)
|
||||
|
||||
if isinstance(tool_calling, ToolUsageError):
|
||||
return ToolResult(tool_calling.message, False)
|
||||
|
||||
if tool_calling.tool_name.casefold().strip() in [
|
||||
name.casefold().strip() for name in tool_name_to_tool_map
|
||||
] or tool_calling.tool_name.casefold().replace("_", " ") in [
|
||||
name.casefold().strip() for name in tool_name_to_tool_map
|
||||
]:
|
||||
tool = tool_name_to_tool_map.get(tool_calling.tool_name)
|
||||
if not tool:
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([t.name.casefold() for t in tools]),
|
||||
)
|
||||
return ToolResult(result=tool_result, result_as_answer=False)
|
||||
|
||||
tool_input = tool_calling.arguments if tool_calling.arguments else {}
|
||||
hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
)
|
||||
|
||||
before_hooks = get_before_tool_call_hooks()
|
||||
try:
|
||||
for hook in before_hooks:
|
||||
result = hook(hook_context)
|
||||
if result is False:
|
||||
blocked_message = (
|
||||
f"Tool execution blocked by hook. "
|
||||
f"Tool: {tool_calling.tool_name}"
|
||||
)
|
||||
return ToolResult(blocked_message, False)
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in before_tool_call hook: {e}")
|
||||
|
||||
tool_result = await tool_usage.ause(tool_calling, agent_action.text)
|
||||
|
||||
after_hook_context = ToolCallHookContext(
|
||||
tool_name=tool_calling.tool_name,
|
||||
tool_input=tool_input,
|
||||
tool=tool,
|
||||
agent=agent,
|
||||
task=task,
|
||||
crew=crew,
|
||||
tool_result=tool_result,
|
||||
)
|
||||
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result: str = tool_result
|
||||
try:
|
||||
for after_hook in after_hooks:
|
||||
hook_result = after_hook(after_hook_context)
|
||||
if hook_result is not None:
|
||||
modified_result = hook_result
|
||||
after_hook_context.tool_result = modified_result
|
||||
except Exception as e:
|
||||
logger.log("error", f"Error in after_tool_call hook: {e}")
|
||||
|
||||
return ToolResult(modified_result, tool.result_as_answer)
|
||||
|
||||
tool_result = i18n.errors("wrong_tool_name").format(
|
||||
tool=tool_calling.tool_name,
|
||||
tools=", ".join([tool.name.casefold() for tool in tools]),
|
||||
)
|
||||
return ToolResult(result=tool_result, result_as_answer=False)
|
||||
|
||||
|
||||
def execute_tool_and_check_finality(
|
||||
agent_action: AgentAction,
|
||||
tools: list[CrewStructuredTool],
|
||||
@@ -141,10 +273,10 @@ def execute_tool_and_check_finality(
|
||||
|
||||
# Execute after_tool_call hooks
|
||||
after_hooks = get_after_tool_call_hooks()
|
||||
modified_result = tool_result
|
||||
modified_result: str = tool_result
|
||||
try:
|
||||
for hook in after_hooks:
|
||||
hook_result = hook(after_hook_context)
|
||||
for after_hook in after_hooks:
|
||||
hook_result = after_hook(after_hook_context)
|
||||
if hook_result is not None:
|
||||
modified_result = hook_result
|
||||
after_hook_context.tool_result = modified_result
|
||||
|
||||
Reference in New Issue
Block a user