mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-05-01 15:22:37 +00:00
feat: enhance AnthropicCompletion class with additional client parame… (#3707)
* feat: enhance AnthropicCompletion class with additional client parameters and tool handling
  - Added support for client_params in the AnthropicCompletion class to allow for additional client configuration.
  - Refactored client initialization to use a dedicated method for retrieving client parameters.
  - Implemented a new method to handle the tool use conversation flow, ensuring proper execution and response handling.
  - Introduced comprehensive test cases to validate the functionality of the AnthropicCompletion class, including tool use scenarios and parameter handling.
* drop print statements
* test: add fixture to mock ANTHROPIC_API_KEY for tests
  - Introduced a pytest fixture to automatically mock the ANTHROPIC_API_KEY environment variable for all tests in the test_anthropic.py module.
  - This change ensures that tests can run without requiring a real API key, improving test isolation and reliability.
* refactor: streamline streaming message handling in AnthropicCompletion class
  - Removed the 'stream' parameter from the API call, as it is set internally by the SDK.
  - Simplified the handling of tool use events and response construction by extracting token usage from the final message.
  - Enhanced the flow for managing the tool use conversation, ensuring proper integration with the streaming API response.
* fix streaming here too
* fix: improve error handling in tool conversion for AnthropicCompletion class
  - Enhanced exception handling during tool conversion by catching KeyError and ValueError.
  - Added logging for conversion errors to aid in debugging and maintain robustness in tool integration.
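As a quick orientation before the diff, here is a minimal usage sketch of the new client_params option. The import path, model name, and header values below are assumptions for illustration, not taken from this commit.

# Illustrative sketch only: import path and values are assumed, not from this diff.
from crewai.llms.providers.anthropic.completion import AnthropicCompletion

llm = AnthropicCompletion(
    model="claude-3-5-sonnet-latest",
    # api_key may be omitted when ANTHROPIC_API_KEY is set in the environment
    client_params={"default_headers": {"X-Team": "research"}},
)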
@@ -386,7 +386,7 @@ class EventListener(BaseEventListener):

                 # Read from the in-memory stream
                 content = self.text_stream.read()
-                _printer.print(content, end="", flush=True)
+                _printer.print(content)
                 self.next_chunk = self.text_stream.tell()

         # ----------- LLM GUARDRAIL EVENTS -----------
@@ -1,4 +1,3 @@
-import json
 import logging
 import os
 from typing import Any
@@ -40,6 +39,7 @@ class AnthropicCompletion(BaseLLM):
         top_p: float | None = None,
         stop_sequences: list[str] | None = None,
         stream: bool = False,
+        client_params: dict[str, Any] | None = None,
         **kwargs,
     ):
         """Initialize Anthropic chat completion client.
@@ -55,19 +55,20 @@ class AnthropicCompletion(BaseLLM):
             top_p: Nucleus sampling parameter
             stop_sequences: Stop sequences (Anthropic uses stop_sequences, not stop)
             stream: Enable streaming responses
+            client_params: Additional parameters for the Anthropic client
             **kwargs: Additional parameters
         """
         super().__init__(
             model=model, temperature=temperature, stop=stop_sequences or [], **kwargs
         )

-        # Initialize Anthropic client
-        self.client = Anthropic(
-            api_key=api_key or os.getenv("ANTHROPIC_API_KEY"),
-            base_url=base_url,
-            timeout=timeout,
-            max_retries=max_retries,
-        )
+        # Client params
+        self.client_params = client_params
+        self.base_url = base_url
+        self.timeout = timeout
+        self.max_retries = max_retries
+
+        self.client = Anthropic(**self._get_client_params())

         # Store completion parameters
         self.max_tokens = max_tokens
@@ -79,6 +80,26 @@ class AnthropicCompletion(BaseLLM):
         self.is_claude_3 = "claude-3" in model.lower()
         self.supports_tools = self.is_claude_3  # Claude 3+ supports tool use

+    def _get_client_params(self) -> dict[str, Any]:
+        """Get client parameters."""
+
+        if self.api_key is None:
+            self.api_key = os.getenv("ANTHROPIC_API_KEY")
+        if self.api_key is None:
+            raise ValueError("ANTHROPIC_API_KEY is required")
+
+        client_params = {
+            "api_key": self.api_key,
+            "base_url": self.base_url,
+            "timeout": self.timeout,
+            "max_retries": self.max_retries,
+        }
+
+        if self.client_params:
+            client_params.update(self.client_params)
+
+        return client_params
+
     def call(
         self,
         messages: str | list[dict[str, str]],
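A small sketch of the merge order in _get_client_params: constructor-derived defaults are built first, then anything passed via client_params is layered on top with dict.update, so user-supplied keys win. The values below are illustrative, not defaults documented by this commit.

# Illustrative values; mirrors the update order used above.
client_params = {
    "api_key": "sk-ant-...",
    "base_url": None,
    "timeout": None,
    "max_retries": 2,
}
client_params.update({"max_retries": 5, "default_headers": {"X-Team": "research"}})
# -> max_retries becomes 5 and default_headers is added;
#    api_key, base_url, and timeout keep their constructor-derived values.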
@@ -183,12 +204,20 @@ class AnthropicCompletion(BaseLLM):

     def _convert_tools_for_interference(self, tools: list[dict]) -> list[dict]:
         """Convert CrewAI tool format to Anthropic tool use format."""
-        from crewai.llms.providers.utils.common import safe_tool_conversion
-
         anthropic_tools = []

         for tool in tools:
-            name, description, parameters = safe_tool_conversion(tool, "Anthropic")
+            if "input_schema" in tool and "name" in tool and "description" in tool:
+                anthropic_tools.append(tool)
+                continue
+
+            try:
+                from crewai.llms.providers.utils.common import safe_tool_conversion
+
+                name, description, parameters = safe_tool_conversion(tool, "Anthropic")
+            except (ImportError, KeyError, ValueError) as e:
+                logging.error(f"Error converting tool to Anthropic format: {e}")
+                raise e

             anthropic_tool = {
                 "name": name,
@@ -196,7 +225,13 @@ class AnthropicCompletion(BaseLLM):
             }

             if parameters and isinstance(parameters, dict):
-                anthropic_tool["input_schema"] = parameters  # type: ignore
+                anthropic_tool["input_schema"] = parameters
+            else:
+                anthropic_tool["input_schema"] = {
+                    "type": "object",
+                    "properties": {},
+                    "required": [],
+                }

             anthropic_tools.append(anthropic_tool)

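To make the two branches above concrete, here is a hedged sketch with made-up tool definitions: one already in Anthropic's shape, which the early continue passes through untouched, and one that still needs normalizing via safe_tool_conversion.

# Hypothetical tool definitions, not taken from this diff.
ready_tool = {
    "name": "get_weather",
    "description": "Look up current weather for a city",
    "input_schema": {
        "type": "object",
        "properties": {"city": {"type": "string"}},
        "required": ["city"],
    },
}
# -> appended to anthropic_tools unchanged.

loose_tool = {"name": "get_weather", "description": "Look up current weather"}
# -> no input_schema, so it goes through safe_tool_conversion; a KeyError or
#    ValueError raised there is now logged before being re-raised, and an empty
#    object schema is substituted when no parameters are returned.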
@@ -229,13 +264,11 @@ class AnthropicCompletion(BaseLLM):
             content = message.get("content", "")

             if role == "system":
                 # Extract system message - Anthropic handles it separately
                 if system_message:
                     system_message += f"\n\n{content}"
                 else:
                     system_message = content
             else:
                 # Add user/assistant messages - ensure both role and content are str, not None
                 role_str = role if role is not None else "user"
                 content_str = content if content is not None else ""
                 formatted_messages.append({"role": role_str, "content": content_str})
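An illustration of what the formatting loop above produces: system turns are folded into a single system string, joined with blank lines when there are several, while everything else stays in the Anthropic messages list.

# Example input/output for the loop shown above.
messages = [
    {"role": "system", "content": "You are terse."},
    {"role": "system", "content": "Answer in English."},
    {"role": "user", "content": "Hi"},
]
# After the loop:
#   system_message     == "You are terse.\n\nAnswer in English."
#   formatted_messages == [{"role": "user", "content": "Hi"}]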
@@ -270,22 +303,22 @@ class AnthropicCompletion(BaseLLM):
         usage = self._extract_anthropic_token_usage(response)
         self._track_token_usage_internal(usage)

         # Check if Claude wants to use tools
         if response.content and available_functions:
-            for content_block in response.content:
-                if isinstance(content_block, ToolUseBlock):
-                    function_name = content_block.name
-                    function_args = content_block.input
+            tool_uses = [
+                block for block in response.content if isinstance(block, ToolUseBlock)
+            ]

-                    result = self._handle_tool_execution(
-                        function_name=function_name,
-                        function_args=function_args,  # type: ignore
-                        available_functions=available_functions,
-                        from_task=from_task,
-                        from_agent=from_agent,
-                    )
-
-                    if result is not None:
-                        return result
+            if tool_uses:
+                # Handle tool use conversation flow
+                return self._handle_tool_use_conversation(
+                    response,
+                    tool_uses,
+                    params,
+                    available_functions,
+                    from_task,
+                    from_agent,
+                )
+
         # Extract text content
         content = ""
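The attributes consumed downstream come straight from the SDK's ToolUseBlock (name, input, and id). A brief, self-contained sketch of the filtering step, assuming only the anthropic package:

from anthropic.types import Message, ToolUseBlock

def collect_tool_uses(response: Message) -> list[ToolUseBlock]:
    """Return every tool-use block Claude requested in a non-streaming response."""
    # e.g. block.id="toolu_01...", block.name="get_weather", block.input={"city": "Paris"}
    return [block for block in response.content if isinstance(block, ToolUseBlock)]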
@@ -318,12 +351,14 @@ class AnthropicCompletion(BaseLLM):
     ) -> str:
         """Handle streaming message completion."""
         full_response = ""
-        tool_uses = {}
-
+        # Remove 'stream' parameter as messages.stream() doesn't accept it
+        # (the SDK sets it internally)
+        stream_params = {k: v for k, v in params.items() if k != "stream"}
+
         # Make streaming API call
-        with self.client.messages.stream(**params) as stream:
+        with self.client.messages.stream(**stream_params) as stream:
             for event in stream:
                 # Handle content delta events
                 if hasattr(event, "delta") and hasattr(event.delta, "text"):
                     text_delta = event.delta.text
                     full_response += text_delta
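A one-line illustration of the filter above: the prepared request params may still carry a stream entry, but messages.stream() manages streaming itself, so the key is dropped before the call.

params = {"model": "claude-3-5-sonnet-latest", "max_tokens": 1024, "stream": True}
stream_params = {k: v for k, v in params.items() if k != "stream"}
# stream_params == {"model": "claude-3-5-sonnet-latest", "max_tokens": 1024}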
@@ -333,44 +368,29 @@ class AnthropicCompletion(BaseLLM):
                         from_agent=from_agent,
                     )

-                # Handle tool use events
-                elif hasattr(event, "delta") and hasattr(event.delta, "partial_json"):
-                    # Tool use streaming - accumulate JSON
-                    tool_id = getattr(event, "index", "default")
-                    if tool_id not in tool_uses:
-                        tool_uses[tool_id] = {
-                            "name": "",
-                            "input": "",
-                        }
+            final_message: Message = stream.get_final_message()

-                    if hasattr(event.delta, "name"):
-                        tool_uses[tool_id]["name"] = event.delta.name
-                    if hasattr(event.delta, "partial_json"):
-                        tool_uses[tool_id]["input"] += event.delta.partial_json
+        usage = self._extract_anthropic_token_usage(final_message)
+        self._track_token_usage_internal(usage)

-        # Handle completed tool uses
-        if tool_uses and available_functions:
-            for tool_data in tool_uses.values():
-                function_name = tool_data["name"]
+        if final_message.content and available_functions:
+            tool_uses = [
+                block
+                for block in final_message.content
+                if isinstance(block, ToolUseBlock)
+            ]

-                try:
-                    function_args = json.loads(tool_data["input"])
-                except json.JSONDecodeError as e:
-                    logging.error(f"Failed to parse streamed tool arguments: {e}")
-                    continue
-
-                # Execute tool
-                result = self._handle_tool_execution(
-                    function_name=function_name,
-                    function_args=function_args,
-                    available_functions=available_functions,
-                    from_task=from_task,
-                    from_agent=from_agent,
+            if tool_uses:
+                # Handle tool use conversation flow
+                return self._handle_tool_use_conversation(
+                    final_message,
+                    tool_uses,
+                    params,
+                    available_functions,
+                    from_task,
+                    from_agent,
                 )

-                if result is not None:
-                    return result
-
         # Apply stop words to full response
         full_response = self._apply_stop_words(full_response)

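The streaming rework leans on the SDK accumulating the full message: token usage and any tool-use blocks are read from stream.get_final_message() instead of being stitched together from partial_json deltas. A hedged sketch of that pattern with the anthropic client; the model and messages are placeholders.

# Sketch only; requires ANTHROPIC_API_KEY in the environment.
import anthropic

client = anthropic.Anthropic()
with client.messages.stream(
    model="claude-3-5-sonnet-latest",
    max_tokens=1024,
    messages=[{"role": "user", "content": "What's the weather in Paris?"}],
) as stream:
    for text in stream.text_stream:
        print(text, end="")
    final_message = stream.get_final_message()

print(final_message.usage.input_tokens, final_message.usage.output_tokens)
tool_blocks = [b for b in final_message.content if b.type == "tool_use"]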
@@ -385,6 +405,113 @@ class AnthropicCompletion(BaseLLM):

         return full_response

+    def _handle_tool_use_conversation(
+        self,
+        initial_response: Message,
+        tool_uses: list[ToolUseBlock],
+        params: dict[str, Any],
+        available_functions: dict[str, Any],
+        from_task: Any | None = None,
+        from_agent: Any | None = None,
+    ) -> str:
+        """Handle the complete tool use conversation flow.
+
+        This implements the proper Anthropic tool use pattern:
+        1. Claude requests tool use
+        2. We execute the tools
+        3. We send tool results back to Claude
+        4. Claude processes results and generates final response
+        """
+        # Execute all requested tools and collect results
+        tool_results = []
+
+        for tool_use in tool_uses:
+            function_name = tool_use.name
+            function_args = tool_use.input
+
+            # Execute the tool
+            result = self._handle_tool_execution(
+                function_name=function_name,
+                function_args=function_args,  # type: ignore
+                available_functions=available_functions,
+                from_task=from_task,
+                from_agent=from_agent,
+            )
+
+            # Create tool result in Anthropic format
+            tool_result = {
+                "type": "tool_result",
+                "tool_use_id": tool_use.id,
+                "content": str(result)
+                if result is not None
+                else "Tool execution completed",
+            }
+            tool_results.append(tool_result)
+
+        # Prepare follow-up conversation with tool results
+        follow_up_params = params.copy()
+
+        # Add Claude's tool use response to conversation
+        assistant_message = {"role": "assistant", "content": initial_response.content}
+
+        # Add user message with tool results
+        user_message = {"role": "user", "content": tool_results}
+
+        # Update messages for follow-up call
+        follow_up_params["messages"] = params["messages"] + [
+            assistant_message,
+            user_message,
+        ]
+
+        try:
+            # Send tool results back to Claude for final response
+            final_response: Message = self.client.messages.create(**follow_up_params)
+
+            # Track token usage for follow-up call
+            follow_up_usage = self._extract_anthropic_token_usage(final_response)
+            self._track_token_usage_internal(follow_up_usage)
+
+            # Extract final text content
+            final_content = ""
+            if final_response.content:
+                for content_block in final_response.content:
+                    if hasattr(content_block, "text"):
+                        final_content += content_block.text
+
+            final_content = self._apply_stop_words(final_content)
+
+            # Emit completion event for the final response
+            self._emit_call_completed_event(
+                response=final_content,
+                call_type=LLMCallType.LLM_CALL,
+                from_task=from_task,
+                from_agent=from_agent,
+                messages=follow_up_params["messages"],
+            )
+
+            # Log combined token usage
+            total_usage = {
+                "input_tokens": follow_up_usage.get("input_tokens", 0),
+                "output_tokens": follow_up_usage.get("output_tokens", 0),
+                "total_tokens": follow_up_usage.get("total_tokens", 0),
+            }
+
+            if total_usage.get("total_tokens", 0) > 0:
+                logging.info(f"Anthropic API tool conversation usage: {total_usage}")
+
+            return final_content
+
+        except Exception as e:
+            if is_context_length_exceeded(e):
+                logging.error(f"Context window exceeded in tool follow-up: {e}")
+                raise LLMContextLengthExceededError(str(e)) from e
+
+            logging.error(f"Tool follow-up conversation failed: {e}")
+            # Fallback: return the first tool result if follow-up fails
+            if tool_results:
+                return tool_results[0]["content"]
+            raise e
+
     def supports_function_calling(self) -> bool:
         """Check if the model supports function calling."""
         return self.supports_tools
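For clarity on the wire format that _handle_tool_use_conversation builds in step 3 of its docstring: Claude's tool_use turn is echoed back as an assistant message, followed by a user message whose content is a list of tool_result blocks keyed by tool_use_id. A sketch with made-up ids and values, not taken from this commit:

# Illustrative follow-up payload; ids and contents are made up.
original_messages = [{"role": "user", "content": "What's the weather in Paris?"}]
assistant_tool_use_turn = {
    "role": "assistant",
    "content": [
        {"type": "tool_use", "id": "toolu_01ABC", "name": "get_weather", "input": {"city": "Paris"}}
    ],
}
follow_up_messages = original_messages + [
    assistant_tool_use_turn,
    {
        "role": "user",
        "content": [
            {"type": "tool_result", "tool_use_id": "toolu_01ABC", "content": "22C and sunny"}
        ],
    },
]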