Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-29 18:18:13 +00:00
Lorenze/ensure hooks work with lite agents flows (#3981)
* liteagent support hooks
* wip llm.call hooks work - needs tests for this
* fix tests
* fixed more
* more tool hooks test cassettes
@@ -1,6 +1,6 @@
 from __future__ import annotations

-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any, cast

 from crewai.events.event_listener import event_listener
 from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
@@ -9,17 +9,22 @@ from crewai.utilities.printer import Printer
 if TYPE_CHECKING:
     from crewai.agents.crew_agent_executor import CrewAgentExecutor
+    from crewai.lite_agent import LiteAgent
+    from crewai.llms.base_llm import BaseLLM
     from crewai.utilities.types import LLMMessage


 class LLMCallHookContext:
-    """Context object passed to LLM call hooks with full executor access.
+    """Context object passed to LLM call hooks.

-    Provides hooks with complete access to the executor state, allowing
+    Provides hooks with complete access to the execution state, allowing
     modification of messages, responses, and executor attributes.

+    Supports both executor-based calls (agents in crews/flows) and direct LLM calls.
+
     Attributes:
-        executor: Full reference to the CrewAgentExecutor instance
-        messages: Direct reference to executor.messages (mutable list).
+        executor: Reference to the executor (CrewAgentExecutor/LiteAgent) or None for direct calls
+        messages: Direct reference to messages (mutable list).
            Can be modified in both before_llm_call and after_llm_call hooks.
            Modifications in after_llm_call hooks persist to the next iteration,
            allowing hooks to modify conversation history for subsequent LLM calls.
@@ -27,33 +32,75 @@ class LLMCallHookContext:
            Do NOT replace the list (e.g., context.messages = []), as this will break
            the executor. Use context.messages.append() or context.messages.extend()
            instead of assignment.
-        agent: Reference to the agent executing the task
-        task: Reference to the task being executed
-        crew: Reference to the crew instance
+        agent: Reference to the agent executing the task (None for direct LLM calls)
+        task: Reference to the task being executed (None for direct LLM calls or LiteAgent)
+        crew: Reference to the crew instance (None for direct LLM calls or LiteAgent)
         llm: Reference to the LLM instance
-        iterations: Current iteration count
+        iterations: Current iteration count (0 for direct LLM calls)
         response: LLM response string (only set for after_llm_call hooks).
            Can be modified by returning a new string from after_llm_call hook.
     """

+    executor: CrewAgentExecutor | LiteAgent | None
+    messages: list[LLMMessage]
+    agent: Any
+    task: Any
+    crew: Any
+    llm: BaseLLM | None | str | Any
+    iterations: int
+    response: str | None
+
     def __init__(
         self,
-        executor: CrewAgentExecutor,
+        executor: CrewAgentExecutor | LiteAgent | None = None,
         response: str | None = None,
+        messages: list[LLMMessage] | None = None,
+        llm: BaseLLM | str | Any | None = None,  # TODO: look into
+        agent: Any | None = None,
+        task: Any | None = None,
+        crew: Any | None = None,
     ) -> None:
-        """Initialize hook context with executor reference.
+        """Initialize hook context with executor reference or direct parameters.

         Args:
-            executor: The CrewAgentExecutor instance
+            executor: The CrewAgentExecutor or LiteAgent instance (None for direct LLM calls)
             response: Optional response string (for after_llm_call hooks)
+            messages: Optional messages list (for direct LLM calls when executor is None)
+            llm: Optional LLM instance (for direct LLM calls when executor is None)
+            agent: Optional agent reference (for direct LLM calls when executor is None)
+            task: Optional task reference (for direct LLM calls when executor is None)
+            crew: Optional crew reference (for direct LLM calls when executor is None)
         """
-        self.executor = executor
-        self.messages = executor.messages
-        self.agent = executor.agent
-        self.task = executor.task
-        self.crew = executor.crew
-        self.llm = executor.llm
-        self.iterations = executor.iterations
+        if executor is not None:
+            # Existing path: extract from executor
+            self.executor = executor
+            self.messages = executor.messages
+            self.llm = executor.llm
+            self.iterations = executor.iterations
+            # Handle CrewAgentExecutor vs LiteAgent differences
+            if hasattr(executor, "agent"):
+                self.agent = executor.agent
+                self.task = cast("CrewAgentExecutor", executor).task
+                self.crew = cast("CrewAgentExecutor", executor).crew
+            else:
+                # LiteAgent case - is the agent itself, doesn't have task/crew
+                self.agent = (
+                    executor.original_agent
+                    if hasattr(executor, "original_agent")
+                    else executor
+                )
+                self.task = None
+                self.crew = None
+        else:
+            # New path: direct LLM call with explicit parameters
+            self.executor = None
+            self.messages = messages or []
+            self.llm = llm
+            self.agent = agent
+            self.task = task
+            self.crew = crew
+            self.iterations = 0
+
         self.response = response

     def request_human_input(
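The contract documented above (mutate context.messages in place, return False from a before hook to block the call, return a new string from an after hook to replace the response) can be exercised with plain functions. A minimal sketch, not taken from this diff; the hook names and bodies are illustrative only:

from crewai.hooks.llm_hooks import LLMCallHookContext


def redact_before_call(context: LLMCallHookContext) -> bool | None:
    """before_llm_call hook: mutate the shared messages list in place."""
    # Append guidance instead of replacing the list (reassignment would break the executor).
    context.messages.append(
        {"role": "system", "content": "Do not include API keys in your answer."}
    )
    # Returning False here would block the LLM call entirely.
    return None


def tag_after_call(context: LLMCallHookContext) -> str | None:
    """after_llm_call hook: return a new string to replace the response."""
    if context.response is None:
        return None
    # context.agent/task/crew may be None for direct LLM calls or LiteAgent runs.
    return context.response + "\n\n[reviewed by after_llm_call hook]"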
@@ -38,6 +38,8 @@ from crewai.events.types.agent_events import (
 )
 from crewai.events.types.logging_events import AgentLogsExecutionEvent
 from crewai.flow.flow_trackable import FlowTrackable
+from crewai.hooks.llm_hooks import get_after_llm_call_hooks, get_before_llm_call_hooks
+from crewai.hooks.types import AfterLLMCallHookType, BeforeLLMCallHookType
 from crewai.lite_agent_output import LiteAgentOutput
 from crewai.llm import LLM
 from crewai.llms.base_llm import BaseLLM
@@ -155,6 +157,12 @@ class LiteAgent(FlowTrackable, BaseModel):
     _guardrail: GuardrailCallable | None = PrivateAttr(default=None)
     _guardrail_retry_count: int = PrivateAttr(default=0)
     _callbacks: list[TokenCalcHandler] = PrivateAttr(default_factory=list)
+    _before_llm_call_hooks: list[BeforeLLMCallHookType] = PrivateAttr(
+        default_factory=get_before_llm_call_hooks
+    )
+    _after_llm_call_hooks: list[AfterLLMCallHookType] = PrivateAttr(
+        default_factory=get_after_llm_call_hooks
+    )

     @model_validator(mode="after")
     def setup_llm(self) -> Self:
@@ -246,6 +254,26 @@ class LiteAgent(FlowTrackable, BaseModel):
         """Return the original role for compatibility with tool interfaces."""
         return self.role

+    @property
+    def before_llm_call_hooks(self) -> list[BeforeLLMCallHookType]:
+        """Get the before_llm_call hooks for this agent."""
+        return self._before_llm_call_hooks
+
+    @property
+    def after_llm_call_hooks(self) -> list[AfterLLMCallHookType]:
+        """Get the after_llm_call hooks for this agent."""
+        return self._after_llm_call_hooks
+
+    @property
+    def messages(self) -> list[LLMMessage]:
+        """Get the messages list for hook context compatibility."""
+        return self._messages
+
+    @property
+    def iterations(self) -> int:
+        """Get the current iteration count for hook context compatibility."""
+        return self._iterations
+
     def kickoff(
         self,
         messages: str | list[LLMMessage],
@@ -504,7 +532,7 @@ class LiteAgent(FlowTrackable, BaseModel):
             AgentFinish: The final result of the agent execution.
         """
         # Execute the agent loop
-        formatted_answer = None
+        formatted_answer: AgentAction | AgentFinish | None = None
         while not isinstance(formatted_answer, AgentFinish):
             try:
                 if has_reached_max_iterations(self._iterations, self.max_iterations):
@@ -526,6 +554,7 @@ class LiteAgent(FlowTrackable, BaseModel):
                     callbacks=self._callbacks,
                     printer=self._printer,
                     from_agent=self,
+                    executor_context=self,
                 )

             except Exception as e:
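Because the new before_llm_call_hooks / after_llm_call_hooks properties return the underlying mutable lists, per-agent hooks can be attached directly to a LiteAgent before kickoff. A rough sketch, reusing the hook functions from the earlier example and assuming the LiteAgent constructor fields shown (role, goal, backstory, llm); those fields are not part of this diff:

from crewai.lite_agent import LiteAgent

agent = LiteAgent(
    role="Researcher",          # constructor fields assumed, not shown in this diff
    goal="Summarize the input",
    backstory="Concise analyst",
    llm="gpt-4o-mini",
)

# The properties expose the private hook lists, so append() registers a hook
# for this agent only; defaults come from get_before/after_llm_call_hooks().
agent.before_llm_call_hooks.append(redact_before_call)
agent.after_llm_call_hooks.append(tag_after_call)

result = agent.kickoff("Summarize the latest run logs.")
print(result)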
@@ -1642,6 +1642,10 @@ class LLM(BaseLLM):
                 if message.get("role") == "system":
                     msg_role: Literal["assistant"] = "assistant"
                     message["role"] = msg_role

+            if not self._invoke_before_llm_call_hooks(messages, from_agent):
+                raise ValueError("LLM call blocked by before_llm_call hook")
+
             # --- 5) Set up callbacks if provided
             with suppress_warnings():
                 if callbacks and len(callbacks) > 0:
@@ -1651,7 +1655,16 @@ class LLM(BaseLLM):
             params = self._prepare_completion_params(messages, tools)
             # --- 7) Make the completion call and handle response
             if self.stream:
-                return self._handle_streaming_response(
+                result = self._handle_streaming_response(
                     params=params,
                     callbacks=callbacks,
                     available_functions=available_functions,
                     from_task=from_task,
                     from_agent=from_agent,
                     response_model=response_model,
                 )
+            else:
+                result = self._handle_non_streaming_response(
+                    params=params,
+                    callbacks=callbacks,
+                    available_functions=available_functions,
@@ -1660,14 +1673,12 @@ class LLM(BaseLLM):
                     response_model=response_model,
                 )

-            return self._handle_non_streaming_response(
-                params=params,
-                callbacks=callbacks,
-                available_functions=available_functions,
-                from_task=from_task,
-                from_agent=from_agent,
-                response_model=response_model,
-            )
+            if isinstance(result, str):
+                result = self._invoke_after_llm_call_hooks(
+                    messages, result, from_agent
+                )
+
+            return result
         except LLMContextLengthExceededError:
             # Re-raise LLMContextLengthExceededError as it should be handled
             # by the CrewAgentExecutor._invoke_loop method, which can then decide
@@ -314,7 +314,7 @@ class BaseLLM(ABC):
         call_type: LLMCallType,
         from_task: Task | None = None,
         from_agent: Agent | None = None,
-        messages: str | list[dict[str, Any]] | None = None,
+        messages: str | list[LLMMessage] | None = None,
     ) -> None:
         """Emit LLM call completed event."""
         crewai_event_bus.emit(
@@ -586,3 +586,134 @@ class BaseLLM(ABC):
             Dictionary with token usage totals
         """
         return UsageMetrics(**self._token_usage)
+
+    def _invoke_before_llm_call_hooks(
+        self,
+        messages: list[LLMMessage],
+        from_agent: Agent | None = None,
+    ) -> bool:
+        """Invoke before_llm_call hooks for direct LLM calls (no agent context).
+
+        This method should be called by native provider implementations before
+        making the actual LLM call when from_agent is None (direct calls).
+
+        Args:
+            messages: The messages being sent to the LLM
+            from_agent: The agent making the call (None for direct calls)
+
+        Returns:
+            True if LLM call should proceed, False if blocked by hook
+
+        Example:
+            >>> # In a native provider's call() method:
+            >>> if from_agent is None and not self._invoke_before_llm_call_hooks(
+            ...     messages, from_agent
+            ... ):
+            ...     raise ValueError("LLM call blocked by hook")
+        """
+        # Only invoke hooks for direct calls (no agent context)
+        if from_agent is not None:
+            return True
+
+        from crewai.hooks.llm_hooks import (
+            LLMCallHookContext,
+            get_before_llm_call_hooks,
+        )
+        from crewai.utilities.printer import Printer
+
+        before_hooks = get_before_llm_call_hooks()
+        if not before_hooks:
+            return True
+
+        hook_context = LLMCallHookContext(
+            executor=None,
+            messages=messages,
+            llm=self,
+            agent=None,
+            task=None,
+            crew=None,
+        )
+        printer = Printer()
+
+        try:
+            for hook in before_hooks:
+                result = hook(hook_context)
+                if result is False:
+                    printer.print(
+                        content="LLM call blocked by before_llm_call hook",
+                        color="yellow",
+                    )
+                    return False
+        except Exception as e:
+            printer.print(
+                content=f"Error in before_llm_call hook: {e}",
+                color="yellow",
+            )
+
+        return True
+
+    def _invoke_after_llm_call_hooks(
+        self,
+        messages: list[LLMMessage],
+        response: str,
+        from_agent: Agent | None = None,
+    ) -> str:
+        """Invoke after_llm_call hooks for direct LLM calls (no agent context).
+
+        This method should be called by native provider implementations after
+        receiving the LLM response when from_agent is None (direct calls).
+
+        Args:
+            messages: The messages that were sent to the LLM
+            response: The response from the LLM
+            from_agent: The agent that made the call (None for direct calls)
+
+        Returns:
+            The potentially modified response string
+
+        Example:
+            >>> # In a native provider's call() method:
+            >>> if from_agent is None and isinstance(result, str):
+            ...     result = self._invoke_after_llm_call_hooks(
+            ...         messages, result, from_agent
+            ...     )
+        """
+        # Only invoke hooks for direct calls (no agent context)
+        if from_agent is not None or not isinstance(response, str):
+            return response
+
+        from crewai.hooks.llm_hooks import (
+            LLMCallHookContext,
+            get_after_llm_call_hooks,
+        )
+        from crewai.utilities.printer import Printer
+
+        after_hooks = get_after_llm_call_hooks()
+        if not after_hooks:
+            return response
+
+        hook_context = LLMCallHookContext(
+            executor=None,
+            messages=messages,
+            llm=self,
+            agent=None,
+            task=None,
+            crew=None,
+            response=response,
+        )
+        printer = Printer()
+        modified_response = response
+
+        try:
+            for hook in after_hooks:
+                result = hook(hook_context)
+                if result is not None and isinstance(result, str):
+                    modified_response = result
+                    hook_context.response = modified_response
+        except Exception as e:
+            printer.print(
+                content=f"Error in after_llm_call hook: {e}",
+                color="yellow",
+            )
+
+        return modified_response
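The two helpers are meant to be dropped into a provider's call path, mirroring the docstring examples above. A condensed sketch of that pattern for a hypothetical custom provider; the class name, simplified call() signature, and _send_request transport are placeholders, and any other abstract members of BaseLLM are omitted here:

from typing import Any

from crewai.llms.base_llm import BaseLLM


class MyProviderCompletion(BaseLLM):
    """Hypothetical native provider showing where the hook helpers slot in."""

    def call(self, messages, tools=None, from_task=None, from_agent=None, **kwargs) -> str | Any:
        # _format_messages is used the same way by the built-in providers in this diff.
        formatted_messages = self._format_messages(messages)

        # Block the request if any global before_llm_call hook returns False
        # (the helper only acts on direct calls, i.e. when from_agent is None).
        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
            raise ValueError("LLM call blocked by before_llm_call hook")

        response = self._send_request(formatted_messages, tools)  # placeholder transport

        # Give after_llm_call hooks a chance to rewrite the string response.
        return self._invoke_after_llm_call_hooks(formatted_messages, response, from_agent)

    def _send_request(self, messages, tools) -> str:
        raise NotImplementedError  # transport details are out of scope for this sketch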
@@ -187,6 +187,9 @@ class AnthropicCompletion(BaseLLM):
             messages
         )

+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+            raise ValueError("LLM call blocked by before_llm_call hook")
+
         # Prepare completion parameters
         completion_params = self._prepare_completion_params(
             formatted_messages, system_message, tools
@@ -494,7 +497,9 @@ class AnthropicCompletion(BaseLLM):
         if usage.get("total_tokens", 0) > 0:
             logging.info(f"Anthropic API usage: {usage}")

-        return content
+        return self._invoke_after_llm_call_hooks(
+            params["messages"], content, from_agent
+        )

     def _handle_streaming_completion(
         self,
@@ -588,7 +593,9 @@ class AnthropicCompletion(BaseLLM):
             messages=params["messages"],
         )

-        return full_response
+        return self._invoke_after_llm_call_hooks(
+            params["messages"], full_response, from_agent
+        )

     def _handle_tool_use_conversation(
         self,
@@ -216,6 +216,9 @@ class AzureCompletion(BaseLLM):
         # Format messages for Azure
         formatted_messages = self._format_messages_for_azure(messages)

+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+            raise ValueError("LLM call blocked by before_llm_call hook")
+
         # Prepare completion parameters
         completion_params = self._prepare_completion_params(
             formatted_messages, tools, response_model
@@ -550,6 +553,10 @@ class AzureCompletion(BaseLLM):
                 messages=params["messages"],
             )

+            content = self._invoke_after_llm_call_hooks(
+                params["messages"], content, from_agent
+            )
+
         except Exception as e:
             if is_context_length_exceeded(e):
                 logging.error(f"Context window exceeded: {e}")
@@ -642,7 +649,9 @@ class AzureCompletion(BaseLLM):
             messages=params["messages"],
         )

-        return full_response
+        return self._invoke_after_llm_call_hooks(
+            params["messages"], full_response, from_agent
+        )

     async def _ahandle_completion(
         self,
@@ -312,9 +312,14 @@ class BedrockCompletion(BaseLLM):

             # Format messages for Converse API
             formatted_messages, system_message = self._format_messages_for_converse(
-                messages  # type: ignore[arg-type]
+                messages
             )

+            if not self._invoke_before_llm_call_hooks(
+                cast(list[LLMMessage], formatted_messages), from_agent
+            ):
+                raise ValueError("LLM call blocked by before_llm_call hook")
+
             # Prepare request body
             body: BedrockConverseRequestBody = {
                 "inferenceConfig": self._get_inference_config(),
@@ -356,11 +361,19 @@ class BedrockCompletion(BaseLLM):

             if self.stream:
                 return self._handle_streaming_converse(
-                    formatted_messages, body, available_functions, from_task, from_agent
+                    cast(list[LLMMessage], formatted_messages),
+                    body,
+                    available_functions,
+                    from_task,
+                    from_agent,
                 )

             return self._handle_converse(
-                formatted_messages, body, available_functions, from_task, from_agent
+                cast(list[LLMMessage], formatted_messages),
+                body,
+                available_functions,
+                from_task,
+                from_agent,
             )

         except Exception as e:
@@ -481,7 +494,7 @@ class BedrockCompletion(BaseLLM):

     def _handle_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: Mapping[str, Any] | None = None,
         from_task: Any | None = None,
@@ -605,7 +618,11 @@ class BedrockCompletion(BaseLLM):
                 messages=messages,
             )

-            return text_content
+            return self._invoke_after_llm_call_hooks(
+                messages,
+                text_content,
+                from_agent,
+            )

         except ClientError as e:
             # Handle all AWS ClientError exceptions as per documentation
@@ -662,7 +679,7 @@ class BedrockCompletion(BaseLLM):

     def _handle_streaming_converse(
         self,
-        messages: list[dict[str, Any]],
+        messages: list[LLMMessage],
         body: BedrockConverseRequestBody,
         available_functions: dict[str, Any] | None = None,
         from_task: Any | None = None,
@@ -1149,16 +1166,25 @@ class BedrockCompletion(BaseLLM):
                 messages=messages,
            )

-            return full_response
+            return self._invoke_after_llm_call_hooks(
+                messages,
+                full_response,
+                from_agent,
+            )

     def _format_messages_for_converse(
-        self, messages: str | list[dict[str, str]]
+        self, messages: str | list[LLMMessage]
     ) -> tuple[list[dict[str, Any]], str | None]:
-        """Format messages for Converse API following AWS documentation."""
-        # Use base class formatting first
-        formatted_messages = self._format_messages(messages)  # type: ignore[arg-type]
-
-        converse_messages = []
+        """Format messages for Converse API following AWS documentation.
+
+        Note: Returns dict[str, Any] instead of LLMMessage because Bedrock uses
+        a different content structure: {"role": str, "content": [{"text": str}]}
+        rather than the standard {"role": str, "content": str}.
+        """
+        # Use base class formatting first
+        formatted_messages = self._format_messages(messages)
+
+        converse_messages: list[dict[str, Any]] = []
         system_message: str | None = None

         for message in formatted_messages:
@@ -246,6 +246,11 @@ class GeminiCompletion(BaseLLM):
             messages
         )

+        messages_for_hooks = self._convert_contents_to_dict(formatted_content)
+
+        if not self._invoke_before_llm_call_hooks(messages_for_hooks, from_agent):
+            raise ValueError("LLM call blocked by before_llm_call hook")
+
         config = self._prepare_generation_config(
             system_instruction, tools, response_model
         )
@@ -559,7 +564,9 @@ class GeminiCompletion(BaseLLM):
             messages=messages_for_event,
         )

-        return content
+        return self._invoke_after_llm_call_hooks(
+            messages_for_event, content, from_agent
+        )

     def _handle_streaming_completion(
         self,
@@ -639,7 +646,9 @@ class GeminiCompletion(BaseLLM):
             messages=messages_for_event,
         )

-        return full_response
+        return self._invoke_after_llm_call_hooks(
+            messages_for_event, full_response, from_agent
+        )

     async def _ahandle_completion(
         self,
@@ -787,7 +796,159 @@ class GeminiCompletion(BaseLLM):
             messages=messages_for_event,
         )

-        return full_response
+        return self._invoke_after_llm_call_hooks(
+            messages_for_event, full_response, from_agent
+        )
+
+    async def _ahandle_completion(
+        self,
+        contents: list[types.Content],
+        system_instruction: str | None,
+        config: types.GenerateContentConfig,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str | Any:
+        """Handle async non-streaming content generation."""
+        try:
+            # The API accepts list[Content] but mypy is overly strict about variance
+            contents_for_api: Any = contents
+            response = await self.client.aio.models.generate_content(
+                model=self.model,
+                contents=contents_for_api,
+                config=config,
+            )
+
+            usage = self._extract_token_usage(response)
+        except Exception as e:
+            if is_context_length_exceeded(e):
+                logging.error(f"Context window exceeded: {e}")
+                raise LLMContextLengthExceededError(str(e)) from e
+            raise e from e
+
+        self._track_token_usage_internal(usage)
+
+        if response.candidates and (self.tools or available_functions):
+            candidate = response.candidates[0]
+            if candidate.content and candidate.content.parts:
+                for part in candidate.content.parts:
+                    if hasattr(part, "function_call") and part.function_call:
+                        function_name = part.function_call.name
+                        if function_name is None:
+                            continue
+                        function_args = (
+                            dict(part.function_call.args)
+                            if part.function_call.args
+                            else {}
+                        )
+
+                        result = self._handle_tool_execution(
+                            function_name=function_name,
+                            function_args=function_args,
+                            available_functions=available_functions or {},
+                            from_task=from_task,
+                            from_agent=from_agent,
+                        )
+
+                        if result is not None:
+                            return result
+
+        content = response.text or ""
+        content = self._apply_stop_words(content)
+
+        messages_for_event = self._convert_contents_to_dict(contents)
+
+        self._emit_call_completed_event(
+            response=content,
+            call_type=LLMCallType.LLM_CALL,
+            from_task=from_task,
+            from_agent=from_agent,
+            messages=messages_for_event,
+        )
+
+        return content
+
+    async def _ahandle_streaming_completion(
+        self,
+        contents: list[types.Content],
+        config: types.GenerateContentConfig,
+        available_functions: dict[str, Any] | None = None,
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+        response_model: type[BaseModel] | None = None,
+    ) -> str:
+        """Handle async streaming content generation."""
+        full_response = ""
+        function_calls: dict[str, dict[str, Any]] = {}
+
+        # The API accepts list[Content] but mypy is overly strict about variance
+        contents_for_api: Any = contents
+        stream = await self.client.aio.models.generate_content_stream(
+            model=self.model,
+            contents=contents_for_api,
+            config=config,
+        )
+        async for chunk in stream:
+            if chunk.text:
+                full_response += chunk.text
+                self._emit_stream_chunk_event(
+                    chunk=chunk.text,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+            if chunk.candidates:
+                candidate = chunk.candidates[0]
+                if candidate.content and candidate.content.parts:
+                    for part in candidate.content.parts:
+                        if hasattr(part, "function_call") and part.function_call:
+                            call_id = part.function_call.name or "default"
+                            if call_id not in function_calls:
+                                function_calls[call_id] = {
+                                    "name": part.function_call.name,
+                                    "args": dict(part.function_call.args)
+                                    if part.function_call.args
+                                    else {},
+                                }
+
+        if function_calls and available_functions:
+            for call_data in function_calls.values():
+                function_name = call_data["name"]
+                function_args = call_data["args"]
+
+                # Skip if function_name is None
+                if not isinstance(function_name, str):
+                    continue
+
+                # Ensure function_args is a dict
+                if not isinstance(function_args, dict):
+                    function_args = {}
+
+                result = self._handle_tool_execution(
+                    function_name=function_name,
+                    function_args=function_args,
+                    available_functions=available_functions,
+                    from_task=from_task,
+                    from_agent=from_agent,
+                )
+
+                if result is not None:
+                    return result
+
+        messages_for_event = self._convert_contents_to_dict(contents)
+
+        self._emit_call_completed_event(
+            response=full_response,
+            call_type=LLMCallType.LLM_CALL,
+            from_task=from_task,
+            from_agent=from_agent,
+            messages=messages_for_event,
+        )
+
+        return self._invoke_after_llm_call_hooks(
+            messages_for_event, full_response, from_agent
+        )

     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
@@ -851,7 +1012,7 @@ class GeminiCompletion(BaseLLM):
     def _convert_contents_to_dict(
         self,
         contents: list[types.Content],
-    ) -> list[dict[str, str]]:
+    ) -> list[LLMMessage]:
         """Convert contents to dict format."""
         result: list[dict[str, str]] = []
         for content_obj in contents:
@@ -190,6 +190,9 @@ class OpenAICompletion(BaseLLM):

         formatted_messages = self._format_messages(messages)

+        if not self._invoke_before_llm_call_hooks(formatted_messages, from_agent):
+            raise ValueError("LLM call blocked by before_llm_call hook")
+
         completion_params = self._prepare_completion_params(
             messages=formatted_messages, tools=tools
         )
@@ -474,6 +477,10 @@ class OpenAICompletion(BaseLLM):

             if usage.get("total_tokens", 0) > 0:
                 logging.info(f"OpenAI API usage: {usage}")

+            content = self._invoke_after_llm_call_hooks(
+                params["messages"], content, from_agent
+            )
+
         except NotFoundError as e:
             error_msg = f"Model {self.model} not found: {e}"
             logging.error(error_msg)
@@ -629,7 +636,9 @@ class OpenAICompletion(BaseLLM):
             messages=params["messages"],
         )

-        return full_response
+        return self._invoke_after_llm_call_hooks(
+            params["messages"], full_response, from_agent
+        )

     async def _ahandle_completion(
         self,
@@ -237,7 +237,7 @@ def get_llm_response(
     from_task: Task | None = None,
     from_agent: Agent | LiteAgent | None = None,
     response_model: type[BaseModel] | None = None,
-    executor_context: CrewAgentExecutor | None = None,
+    executor_context: CrewAgentExecutor | LiteAgent | None = None,
 ) -> str:
     """Call the LLM and return the response, handling any invalid responses.

@@ -727,7 +727,7 @@ def load_agent_from_repository(from_repository: str) -> dict[str, Any]:


 def _setup_before_llm_call_hooks(
-    executor_context: CrewAgentExecutor | None, printer: Printer
+    executor_context: CrewAgentExecutor | LiteAgent | None, printer: Printer
 ) -> bool:
     """Setup and invoke before_llm_call hooks for the executor context.

@@ -777,7 +777,7 @@ def _setup_before_llm_call_hooks(


 def _setup_after_llm_call_hooks(
-    executor_context: CrewAgentExecutor | None,
+    executor_context: CrewAgentExecutor | LiteAgent | None,
     answer: str,
     printer: Printer,
 ) -> str:
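Taken together, the executor_context plumbing above means hooks attached to a LiteAgent also fire when that agent runs inside a flow. A rough sketch, reusing the hook functions and the assumed LiteAgent fields from the earlier examples; the Flow API usage (crewai.flow.flow.Flow, @start) is not part of this diff and is assumed here:

from crewai.flow.flow import Flow, start
from crewai.lite_agent import LiteAgent


class ReviewFlow(Flow):
    @start()
    def review(self) -> str:
        agent = LiteAgent(
            role="Reviewer",            # fields assumed, as in the earlier sketch
            goal="Review the draft",
            backstory="Careful editor",
            llm="gpt-4o-mini",
        )
        # redact_before_call / tag_after_call come from the earlier sketch.
        # Hooks registered on the LiteAgent are passed along via executor_context=self
        # inside get_llm_response(), so they run for flow executions as well.
        agent.before_llm_call_hooks.append(redact_before_call)
        agent.after_llm_call_hooks.append(tag_after_call)
        return str(agent.kickoff("Review this draft for tone."))


result = ReviewFlow().kickoff()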