mirror of https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 15:48:29 +00:00

Compare commits: devin/1739...devin/1739 (2 commits)
| Author | SHA1 | Date |
|---|---|---|
|  | f519b0d31c |  |
|  | be11f9c036 |  |
@@ -314,6 +314,46 @@ class BaseAgent(ABC, BaseModel):
         return copied_agent
 
+    def _interpolate_only(self, input_string: str, inputs: Dict[str, Any]) -> str:
+        """Interpolate placeholders in a string while preserving JSON-like structures.
+
+        Args:
+            input_string (str): The string containing placeholders to interpolate.
+            inputs (Dict[str, Any]): Dictionary of values for interpolation.
+
+        Returns:
+            str: The interpolated string with JSON structures preserved.
+
+        Example:
+            >>> _interpolate_only("Name: {name}, Config: {'key': 'value'}", {"name": "John"})
+            "Name: John, Config: {'key': 'value'}"
+
+        Raises:
+            ValueError: If input_string is None or empty, or if inputs is empty
+            KeyError: If a required template variable is missing from inputs
+        """
+        if not input_string:
+            raise ValueError("Input string cannot be None or empty")
+        if not inputs:
+            raise ValueError("Inputs dictionary cannot be empty")
+
+        try:
+            # First check that all required variables are present. Brace groups
+            # that are not valid identifiers (e.g. JSON literals such as
+            # {'key': 'value'}) are not treated as placeholders, so the
+            # docstring example above actually holds.
+            required_vars = [
+                var.split("}")[0] for var in input_string.split("{")[1:]
+                if "}" in var
+            ]
+            for var in required_vars:
+                if var.isidentifier() and var not in inputs:
+                    raise KeyError(f"Missing required template variable: {var}")
+
+            # Escape every brace, then re-open only the known placeholders, so
+            # str.format() fills them while leaving JSON-like fragments intact.
+            escaped_string = input_string.replace("{", "{{").replace("}", "}}")
+            for key in inputs.keys():
+                escaped_string = escaped_string.replace(f"{{{{{key}}}}}", f"{{{key}}}")
+            return escaped_string.format(**inputs)
+        except ValueError as e:
+            raise ValueError(f"Error during string interpolation: {str(e)}") from e
+
     def interpolate_inputs(self, inputs: Dict[str, Any]) -> None:
         """Interpolate inputs into the agent description and backstory."""
         if self._original_role is None:
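The escaping trick above is easy to exercise on its own. Here is a minimal standalone sketch of the same approach (the free function and demo strings are illustrative, not part of the diff):

```python
from typing import Any, Dict


def interpolate_only(input_string: str, inputs: Dict[str, Any]) -> str:
    """Standalone sketch of the escape/unescape interpolation above."""
    # Escape every brace, then re-open only the known placeholders, so
    # str.format() fills {name} but leaves JSON-like fragments untouched.
    escaped = input_string.replace("{", "{{").replace("}", "}}")
    for key in inputs:
        escaped = escaped.replace(f"{{{{{key}}}}}", f"{{{key}}}")
    return escaped.format(**inputs)


print(interpolate_only("Name: {name}, Config: {'key': 'value'}", {"name": "John"}))
# -> Name: John, Config: {'key': 'value'}
```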
@@ -324,9 +364,9 @@ class BaseAgent(ABC, BaseModel):
             self._original_backstory = self.backstory
 
         if inputs:
-            self.role = self._original_role.format(**inputs)
-            self.goal = self._original_goal.format(**inputs)
-            self.backstory = self._original_backstory.format(**inputs)
+            self.role = self._interpolate_only(self._original_role, inputs)
+            self.goal = self._interpolate_only(self._original_goal, inputs)
+            self.backstory = self._interpolate_only(self._original_backstory, inputs)
 
     def set_cache_handler(self, cache_handler: CacheHandler) -> None:
         """Set the cache handler for the agent.
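The motivation for the switch: plain str.format treats every brace group as a replacement field, so a template that embeds a JSON-like fragment fails outright. A quick illustration (the strings are made up):

```python
template = "Config: {'key': 'value'}, topic: {topic}"
try:
    template.format(topic="AI")
except KeyError as err:
    # format() parses {'key': 'value'} as a field named "'key'" and fails
    print(f"plain str.format breaks on JSON braces: KeyError {err}")
```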
@@ -92,19 +92,6 @@ LLM_CONTEXT_WINDOW_SIZES = {
     "Meta-Llama-3.2-1B-Instruct": 16384,
 }
 
-# Common Vertex AI regions
-VERTEX_AI_REGIONS = [
-    "us-central1",  # Iowa
-    "us-east1",  # South Carolina
-    "us-west1",  # Oregon
-    "europe-west1",  # Belgium
-    "europe-west2",  # London
-    "europe-west3",  # Frankfurt
-    "europe-west4",  # Netherlands
-    "asia-east1",  # Taiwan
-    "asia-southeast1"  # Singapore
-]
-
 DEFAULT_CONTEXT_WINDOW_SIZE = 8192
 CONTEXT_WINDOW_USAGE_RATIO = 0.75
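For orientation, the two constants that survive this hunk presumably define the fallback window for models missing from the size table; the usable budget works out to 8192 × 0.75 = 6144 tokens:

```python
DEFAULT_CONTEXT_WINDOW_SIZE = 8192
CONTEXT_WINDOW_USAGE_RATIO = 0.75

# Fallback usable context: 8192 * 0.75 = 6144 tokens
assert int(DEFAULT_CONTEXT_WINDOW_SIZE * CONTEXT_WINDOW_USAGE_RATIO) == 6144
```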
@@ -130,20 +117,9 @@ def suppress_warnings():
 
 
 class LLM:
-    """A wrapper around LiteLLM providing a unified interface for various LLM providers.
-
-    Args:
-        model (str): The identifier of the LLM model to use
-        location (Optional[str]): The GCP region for Vertex AI models (e.g., 'us-central1', 'europe-west4').
-            Only applicable for Vertex AI models.
-        timeout (Optional[Union[float, int]]): Maximum time to wait for the model response
-        temperature (Optional[float]): Controls randomness in the model's output
-        top_p (Optional[float]): Controls diversity of the model's output
-    """
     def __init__(
         self,
         model: str,
-        location: Optional[str] = None,
         timeout: Optional[Union[float, int]] = None,
         temperature: Optional[float] = None,
         top_p: Optional[float] = None,
@@ -167,18 +143,6 @@ class LLM:
         **kwargs,
     ):
         self.model = model
-
-        # Validate location parameter
-        if location is not None:
-            if not isinstance(location, str):
-                raise ValueError("Location must be a string when provided")
-            if self._is_vertex_model(model) and location not in VERTEX_AI_REGIONS:
-                raise ValueError(
-                    f"Invalid Vertex AI region: {location}. "
-                    f"Supported regions: {', '.join(VERTEX_AI_REGIONS)}"
-                )
-        self.location = location
-
         self.timeout = timeout
         self.temperature = temperature
         self.top_p = top_p
@@ -202,10 +166,6 @@ class LLM:
         self.additional_params = kwargs
         self.is_anthropic = self._is_anthropic_model(model)
 
-        # Set vertex location if provided for vertex models
-        if self.location and self._is_vertex_model(model):
-            litellm.vertex_location = self.location
-
         litellm.drop_params = True
 
         # Normalize self.stop to always be a List[str]
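One likely reason the location plumbing was pulled: `litellm.vertex_location`, as used above, is a module-level setting, so it applies process-wide rather than per LLM instance. A sketch of the conflict, with illustrative values:

```python
import litellm

# Module-level settings are shared by every LLM instance in the process.
litellm.vertex_location = "europe-west4"  # agent A configures its region
litellm.vertex_location = "us-central1"   # agent B silently overrides it
# Any subsequent call made on behalf of agent A now routes to us-central1.
```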
@@ -219,17 +179,6 @@ class LLM:
         self.set_callbacks(callbacks)
         self.set_env_callbacks()
 
-    def _is_vertex_model(self, model: str) -> bool:
-        """Determine if the model is from Vertex AI provider.
-
-        Args:
-            model: The model identifier string.
-
-        Returns:
-            bool: True if the model is from Vertex AI, False otherwise.
-        """
-        return "vertex" in model.lower() or model.startswith("gemini-")
-
     def _is_anthropic_model(self, model: str) -> bool:
         """Determine if the model is from Anthropic provider.
@@ -1357,6 +1357,51 @@ def test_handle_context_length_exceeds_limit_cli_no():
     mock_handle_context.assert_not_called()
 
 
+def test_interpolate_inputs_with_tool_description():
+    from crewai.tools import BaseTool
+
+    class DummyTool(BaseTool):
+        name: str = "dummy_tool"
+        description: str = "Tool Arguments: {'arg': {'description': 'test arg', 'type': 'str'}}"
+
+        def _run(self, arg: str) -> str:
+            """Run the tool."""
+            return f"Dummy result for: {arg}"
+
+    tool = DummyTool()
+    agent = Agent(
+        role="{topic} specialist",
+        goal="Figure {goal} out",
+        backstory="I am the master of {role}\nTools: {tool_desc}",
+    )
+
+    agent.interpolate_inputs({
+        "topic": "AI",
+        "goal": "life",
+        "role": "all things",
+        "tool_desc": tool.description,
+    })
+    assert "Tool Arguments: {'arg': {'description': 'test arg', 'type': 'str'}}" in agent.backstory
+
+
+def test_interpolate_only_error_handling():
+    agent = Agent(
+        role="{topic} specialist",
+        goal="Figure {goal} out",
+        backstory="I am the master of {role}",
+    )
+
+    # Test empty input string
+    with pytest.raises(ValueError, match="Input string cannot be None or empty"):
+        agent._interpolate_only("", {"topic": "AI"})
+
+    # Test empty inputs dictionary
+    with pytest.raises(ValueError, match="Inputs dictionary cannot be empty"):
+        agent._interpolate_only("test {topic}", {})
+
+    # Test missing template variable
+    with pytest.raises(KeyError, match="Missing required template variable"):
+        agent._interpolate_only("test {missing}", {"topic": "AI"})
+
+
 def test_agent_with_all_llm_attributes():
     agent = Agent(
         role="test role",
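One behavior the new tests leave uncovered is the docstring example itself, where the JSON fragment sits in the template rather than in an input value. A test along these lines would pin it down (hypothetical test name; assumes the identifier check in `_interpolate_only` skips non-identifier brace groups):

```python
def test_interpolate_only_preserves_json_in_template():
    agent = Agent(
        role="test role",
        goal="test goal",
        backstory="test backstory",
    )
    result = agent._interpolate_only(
        "Name: {name}, Config: {'key': 'value'}", {"name": "John"}
    )
    assert result == "Name: John, Config: {'key': 'value'}"
```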
@@ -2,7 +2,6 @@ import os
 from time import sleep
 from unittest.mock import MagicMock, patch
 
-import litellm
 import pytest
 from pydantic import BaseModel
@@ -13,31 +12,6 @@ from crewai.utilities.token_counter_callback import TokenCalcHandler
 
 
 # TODO: This test fails without print statement, which makes me think that something is happening asynchronously that we need to eventually fix and dive deeper into at a later date
-@pytest.mark.parametrize("model,location,expected", [
-    ("vertex_ai/gemini-2.0-flash", "europe-west4", "europe-west4"),
-    ("gpt-4", "europe-west4", None),  # Non-vertex model ignores location
-    ("vertex_ai/gemini-2.0-flash", None, None),  # No location provided
-])
-@pytest.mark.vcr(filter_headers=["authorization"])
-def test_vertex_ai_location_setting(model, location, expected):
-    """Test Vertex AI location setting behavior."""
-    llm = LLM(model=model, location=location)
-    assert litellm.vertex_location == expected
-
-    # Reset location after test
-    litellm.vertex_location = None
-
-@pytest.mark.vcr(filter_headers=["authorization"])
-def test_vertex_ai_location_validation():
-    """Test Vertex AI location validation."""
-    # Test invalid location type
-    with pytest.raises(ValueError, match="Location must be a string"):
-        LLM(model="vertex_ai/gemini-2.0-flash", location=123)
-
-    # Test invalid region
-    with pytest.raises(ValueError, match="Invalid Vertex AI region"):
-        LLM(model="vertex_ai/gemini-2.0-flash", location="invalid-region")
-
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_llm_callback_replacement():
     llm1 = LLM(model="gpt-4o-mini")
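Since the removed test had to reset `litellm.vertex_location` by hand, the usual pytest idiom for guarding module-level state, should similar tests return, is a snapshot-and-restore fixture. A sketch (the fixture name is illustrative):

```python
import litellm
import pytest


@pytest.fixture
def restore_vertex_location():
    """Snapshot litellm's module-level vertex_location and restore it afterwards."""
    original = getattr(litellm, "vertex_location", None)
    yield
    litellm.vertex_location = original
```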