Mirror of https://github.com/crewAIInc/crewAI.git, synced 2026-01-07 15:18:29 +00:00
Compare commits: devin/1747...devin/1746 (2 commits)

| Author | SHA1 | Date |
|---|---|---|
| | 5b03d0e0db | |
| | ee308ed322 | |
.github/workflows/linter.yml (vendored), 25 changes
```diff
@@ -5,29 +5,12 @@ on: [pull_request]
 jobs:
   lint:
     runs-on: ubuntu-latest
-    env:
-      TARGET_BRANCH: ${{ github.event.pull_request.base.ref }}
     steps:
       - uses: actions/checkout@v4
-        with:
-          fetch-depth: 0
 
-      - name: Fetch Target Branch
-        run: git fetch origin $TARGET_BRANCH --depth=1
-
-      - name: Install Ruff
-        run: pip install ruff
-
-      - name: Get Changed Python Files
-        id: changed-files
+      - name: Install Requirements
         run: |
-          merge_base=$(git merge-base origin/"$TARGET_BRANCH" HEAD)
-          changed_files=$(git diff --name-only --diff-filter=ACMRTUB "$merge_base" | grep '\.py$' || true)
-          echo "files<<EOF" >> $GITHUB_OUTPUT
-          echo "$changed_files" >> $GITHUB_OUTPUT
-          echo "EOF" >> $GITHUB_OUTPUT
+          pip install ruff
 
-      - name: Run Ruff on Changed Files
-        if: ${{ steps.changed-files.outputs.files != '' }}
-        run: |
-          echo "${{ steps.changed-files.outputs.files }}" | tr " " "\n" | xargs -I{} ruff check "{}"
+      - name: Run Ruff Linter
+        run: ruff check
```
```diff
@@ -2,3 +2,8 @@ exclude = [
     "templates",
     "__init__.py",
 ]
+
+[lint]
+select = [
+    "I", # isort rules
+]
```
```diff
@@ -700,11 +700,4 @@ recent_news = SpaceNewsKnowledgeSource(
 - Configure appropriate embedding models
 - Consider using local embedding providers for faster processing
 </Accordion>
-
-<Accordion title="One Time Knowledge">
-- With the typical file structure provided by CrewAI, knowledge sources are embedded every time the kickoff is triggered.
-- If the knowledge sources are large, this leads to inefficiency and increased latency, as the same data is embedded each time.
-- To resolve this, directly initialize the `knowledge` parameter instead of the `knowledge_sources` parameter.
-- See the [GitHub issue](https://github.com/crewAIInc/crewAI/issues/2755) for the full context.
-</Accordion>
 </AccordionGroup>
```
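The accordion removed above documents a workaround: embed knowledge once and reuse it across kickoffs. Below is a minimal sketch of that pattern. The `Knowledge`/`StringKnowledgeSource` import paths and the `knowledge=` constructor argument are assumptions drawn from the bullets and issue #2755, not from this diff, so treat the exact names as illustrative:

```python
from crewai import Agent, Crew, Task
from crewai.knowledge.knowledge import Knowledge  # assumed import path
from crewai.knowledge.source.string_knowledge_source import StringKnowledgeSource  # assumed import path

# Build the knowledge store once, up front; embedding happens here,
# not inside each kickoff().
source = StringKnowledgeSource(content="Users adore chocolate cake.")
knowledge = Knowledge(collection_name="user_facts", sources=[source])

agent = Agent(role="Researcher", goal="Answer user questions", backstory="...")
task = Task(description="What do users adore?", expected_output="One sentence.", agent=agent)

# Hypothetical: pass the pre-built store via `knowledge` instead of
# `knowledge_sources`, so repeated kickoffs reuse the same embeddings.
crew = Crew(agents=[agent], tasks=[task], knowledge=knowledge)
crew.kickoff()
crew.kickoff()  # no re-embedding on the second run
```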
```diff
@@ -20,7 +20,6 @@ from crewai.tools.agent_tools.agent_tools import AgentTools
 from crewai.utilities import Converter, Prompts
 from crewai.utilities.agent_utils import (
     get_tool_names,
-    load_agent_from_repository,
     parse_tools,
     render_text_description_and_args,
 )
@@ -135,16 +134,6 @@ class Agent(BaseAgent):
         default=None,
         description="Knowledge search query for the agent dynamically generated by the agent.",
     )
-    from_repository: Optional[str] = Field(
-        default=None,
-        description="The Agent's role to be used from your repository.",
-    )
-
-    @model_validator(mode="before")
-    def validate_from_repository(cls, v):
-        if v is not None and (from_repository := v.get("from_repository")):
-            return load_agent_from_repository(from_repository) | v
-        return v
 
     @model_validator(mode="after")
     def post_init_setup(self):
```
```diff
@@ -5,5 +5,5 @@ def get_auth_token() -> str:
     """Get the authentication token."""
     access_token = TokenManager().get_token()
     if not access_token:
-        raise Exception("No token found, make sure you are logged in")
+        raise Exception()
     return access_token
```
```diff
@@ -14,7 +14,6 @@ class PlusAPI:
 
     TOOLS_RESOURCE = "/crewai_plus/api/v1/tools"
     CREWS_RESOURCE = "/crewai_plus/api/v1/crews"
-    AGENTS_RESOURCE = "/crewai_plus/api/v1/agents"
 
     def __init__(self, api_key: str) -> None:
         self.api_key = api_key
@@ -38,9 +37,6 @@ class PlusAPI:
     def get_tool(self, handle: str):
         return self._make_request("GET", f"{self.TOOLS_RESOURCE}/{handle}")
 
-    def get_agent(self, handle: str):
-        return self._make_request("GET", f"{self.AGENTS_RESOURCE}/{handle}")
-
     def publish_tool(
         self,
         handle: str,
```
```diff
@@ -5,7 +5,8 @@ import sys
 import threading
 import warnings
 from collections import defaultdict
-from contextlib import contextmanager, redirect_stderr, redirect_stdout
+from contextlib import contextmanager
+from types import SimpleNamespace
 from typing import (
     Any,
     DefaultDict,
@@ -30,6 +31,7 @@ from crewai.utilities.events.llm_events import (
     LLMCallType,
     LLMStreamChunkEvent,
 )
+from crewai.utilities.events.tool_usage_events import ToolExecutionErrorEvent
 
 with warnings.catch_warnings():
     warnings.simplefilter("ignore", UserWarning)
@@ -43,9 +45,6 @@ with warnings.catch_warnings():
     from litellm.utils import supports_response_schema
 
 
-import io
-from typing import TextIO
-
 from crewai.llms.base_llm import BaseLLM
 from crewai.utilities.events import crewai_event_bus
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
@@ -55,17 +54,12 @@ from crewai.utilities.exceptions.context_window_exceeding_exception import (
 load_dotenv()
 
 
-class FilteredStream(io.TextIOBase):
-    _lock = None
-
-    def __init__(self, original_stream: TextIO):
+class FilteredStream:
+    def __init__(self, original_stream):
         self._original_stream = original_stream
+        self._lock = threading.Lock()
 
-    def write(self, s: str) -> int:
-        if not self._lock:
-            self._lock = threading.Lock()
-
+    def write(self, s) -> int:
         with self._lock:
             # Filter out extraneous messages from LiteLLM
             if (
@@ -220,11 +214,15 @@ def suppress_warnings():
         )
 
     # Redirect stdout and stderr
-    with (
-        redirect_stdout(FilteredStream(sys.stdout)),
-        redirect_stderr(FilteredStream(sys.stderr)),
-    ):
+    old_stdout = sys.stdout
+    old_stderr = sys.stderr
+    sys.stdout = FilteredStream(old_stdout)
+    sys.stderr = FilteredStream(old_stderr)
+    try:
         yield
+    finally:
+        sys.stdout = old_stdout
+        sys.stderr = old_stderr
 
 
 class Delta(TypedDict):
@@ -248,6 +246,9 @@ class AccumulatedToolArgs(BaseModel):
 
 
 class LLM(BaseLLM):
+    ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
+    GEMINI_IDENTIFIERS = ("gemini", "gemma-")
+
     def __init__(
         self,
         model: str,
@@ -321,8 +322,55 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Anthropic, False otherwise.
         """
-        ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
-        return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
+        if not isinstance(model, str):
+            return False
+        return any(prefix in model.lower() for prefix in self.ANTHROPIC_PREFIXES)
+
+    def _is_gemini_model(self, model: str) -> bool:
+        """Determine if the model is from Google Gemini provider.
+
+        Args:
+            model: The model identifier string.
+
+        Returns:
+            bool: True if the model is from Gemini, False otherwise.
+        """
+        if not isinstance(model, str):
+            return False
+        return any(identifier in model.lower() for identifier in self.GEMINI_IDENTIFIERS)
+
+    def _normalize_gemini_model(self, model: str) -> str:
+        """Normalize Gemini model name to the format expected by LiteLLM.
+
+        Handles formats like "models/gemini-pro" or "gemini-pro" and converts
+        them to "gemini/gemini-pro" format.
+
+        Args:
+            model: The model identifier string.
+
+        Returns:
+            str: Normalized model name.
+
+        Raises:
+            ValueError: If model is not a string or is empty.
+        """
+        if not isinstance(model, str):
+            raise ValueError(f"Model must be a string, got {type(model)}")
+
+        if not model.strip():
+            raise ValueError("Model name cannot be empty")
+
+        if model.startswith("gemini/"):
+            return model
+
+        if model.startswith("models/"):
+            model_name = model.split("/", 1)[1]
+            return f"gemini/{model_name}"
+
+        if self._is_gemini_model(model) and "/" not in model:
+            return f"gemini/{model}"
+
+        return model
 
     def _prepare_completion_params(
         self,
@@ -345,9 +393,23 @@ class LLM(BaseLLM):
             messages = [{"role": "user", "content": messages}]
         formatted_messages = self._format_messages_for_provider(messages)
 
         # --- 2) Prepare the parameters for the completion call
+        model = self.model
+        if self._is_gemini_model(model):
+            try:
+                model = self._normalize_gemini_model(model)
+                logging.info(f"Normalized Gemini model name from '{self.model}' to '{model}'")
+
+                # --- 2.1) Map GOOGLE_API_KEY to GEMINI_API_KEY if needed
+                if not os.environ.get("GEMINI_API_KEY") and os.environ.get("GOOGLE_API_KEY"):
+                    os.environ["GEMINI_API_KEY"] = os.environ["GOOGLE_API_KEY"]
+                    logging.info("Mapped GOOGLE_API_KEY to GEMINI_API_KEY for Gemini model")
+            except ValueError as e:
+                logging.error(f"Error normalizing Gemini model: {str(e)}")
+                model = self.model
+
+        # --- 3) Prepare the parameters for the completion call
         params = {
-            "model": self.model,
+            "model": model,
             "messages": formatted_messages,
             "timeout": self.timeout,
             "temperature": self.temperature,
```
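Taken together, the Gemini additions above accept any of the common model spellings and either Google environment variable. A short usage sketch (illustrative only; it assumes `LLM` is importable from the top-level `crewai` package, which this diff does not show):

```python
import os

from crewai import LLM  # assumed public import path

# Only GOOGLE_API_KEY is set; _prepare_completion_params copies it to
# GEMINI_API_KEY before calling litellm.
os.environ["GOOGLE_API_KEY"] = "your-gemini-key"

# Each spelling is normalized to "gemini/gemini-pro" before the litellm call.
for name in ("models/gemini-pro", "gemini-pro", "gemini/gemini-pro"):
    llm = LLM(model=name)
    print(llm.call("What is the capital of France?"))
```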
```diff
@@ -173,18 +173,11 @@ class CrewStructuredTool:
     def _parse_args(self, raw_args: Union[str, dict]) -> dict:
         """Parse and validate the input arguments against the schema.
 
-        This method handles different input formats from various LLM providers,
-        including nested dictionaries with 'value' fields that some providers use.
-
         Args:
-            raw_args: The raw arguments to parse, either as a string or dict.
-                Supports nested dictionaries with 'value' field for LLM provider compatibility.
+            raw_args: The raw arguments to parse, either as a string or dict
 
         Returns:
             The validated arguments as a dictionary
 
         Raises:
             ValueError: If argument parsing or validation fails
         """
         if isinstance(raw_args, str):
             try:
@@ -194,31 +187,6 @@ class CrewStructuredTool:
             except json.JSONDecodeError as e:
                 raise ValueError(f"Failed to parse arguments as JSON: {e}")
 
-        # Handle nested dictionaries with 'value' field for all parameter types
-        if isinstance(raw_args, dict):
-            schema_fields = self.args_schema.model_fields
-
-            for field_name, field_value in list(raw_args.items()):
-                # Check if this field exists in the schema
-                if field_name in schema_fields:
-                    # Handle nested dictionaries with 'value' field
-                    if isinstance(field_value, dict):
-                        if 'value' in field_value:
-                            # Extract the value from the nested dictionary
-                            value = field_value['value']
-                            self._logger.debug(f"Extracting value from nested dict for {field_name}")
-
-                            expected_type = schema_fields[field_name].annotation
-
-                            if expected_type in (str, int, float, bool) and not isinstance(value, expected_type):
-                                self._logger.warning(
-                                    f"Type mismatch for {field_name}: expected {expected_type}, got {type(value)}"
-                                )
-
-                            raw_args[field_name] = value
-                        else:
-                            self._logger.debug(f"Nested dict for {field_name} has no 'value' key")
-
         try:
             validated_args = self.args_schema.model_validate(raw_args)
             return validated_args.model_dump()
```
```diff
@@ -16,7 +16,6 @@ from crewai.tools.base_tool import BaseTool
 from crewai.tools.structured_tool import CrewStructuredTool
 from crewai.tools.tool_types import ToolResult
 from crewai.utilities import I18N, Printer
-from crewai.utilities.errors import AgentRepositoryError
 from crewai.utilities.exceptions.context_window_exceeding_exception import (
     LLMContextLengthExceededException,
 )
@@ -429,36 +428,3 @@ def show_agent_logs(
         printer.print(
             content=f"\033[95m## Final Answer:\033[00m \033[92m\n{formatted_answer.output}\033[00m\n\n"
         )
-
-
-def load_agent_from_repository(from_repository: str) -> Dict[str, Any]:
-    attributes: Dict[str, Any] = {}
-    if from_repository:
-        import importlib
-
-        from crewai.cli.authentication.token import get_auth_token
-        from crewai.cli.plus_api import PlusAPI
-
-        client = PlusAPI(api_key=get_auth_token())
-        response = client.get_agent(from_repository)
-        if response.status_code != 200:
-            raise AgentRepositoryError(
-                f"Agent {from_repository} could not be loaded: {response.text}"
-            )
-
-        agent = response.json()
-        for key, value in agent.items():
-            if key == "tools":
-                attributes[key] = []
-                for tool_name in value:
-                    try:
-                        module = importlib.import_module("crewai_tools")
-                        tool_class = getattr(module, tool_name)
-                        attributes[key].append(tool_class())
-                    except Exception as e:
-                        raise AgentRepositoryError(
-                            f"Tool {tool_name} could not be loaded: {e}"
-                        ) from e
-            else:
-                attributes[key] = value
-    return attributes
```
```diff
@@ -1,5 +1,4 @@
 """Error message definitions for CrewAI database operations."""
-
 from typing import Optional
 
 
@@ -38,9 +37,3 @@ class DatabaseError:
             The formatted error message
         """
         return template.format(str(error))
-
-
-class AgentRepositoryError(Exception):
-    """Exception raised when an agent repository is not found."""
-
-    ...
```
```diff
@@ -2,7 +2,7 @@
 
 import os
 from unittest import mock
-from unittest.mock import MagicMock, patch
+from unittest.mock import patch
 
 import pytest
 
@@ -18,7 +18,6 @@ from crewai.tools import tool
 from crewai.tools.tool_calling import InstructorToolCalling
 from crewai.tools.tool_usage import ToolUsage
 from crewai.utilities import RPMController
-from crewai.utilities.errors import AgentRepositoryError
 from crewai.utilities.events import crewai_event_bus
 from crewai.utilities.events.tool_usage_events import ToolUsageFinishedEvent
 
@@ -309,7 +308,9 @@ def test_cache_hitting():
     def handle_tool_end(source, event):
         received_events.append(event)
 
-    with (patch.object(CacheHandler, "read") as read,):
+    with (
+        patch.object(CacheHandler, "read") as read,
+    ):
         read.return_value = "0"
         task = Task(
             description="What is 2 times 6? Ignore correctness and just return the result of the multiplication tool, you must use the tool.",
@@ -1039,7 +1040,7 @@ def test_agent_human_input():
             CrewAgentExecutor,
             "_invoke_loop",
             return_value=AgentFinish(output="Hello", thought="", text=""),
-        ),
+        ) as mock_invoke_loop,
     ):
         # Execute the task
         output = agent.execute_task(task)
@@ -2024,86 +2025,3 @@ def test_get_knowledge_search_query():
         },
     ]
 )
-
-
-@pytest.fixture
-def mock_get_auth_token():
-    with patch(
-        "crewai.cli.authentication.token.get_auth_token", return_value="test_token"
-    ):
-        yield
-
-
-@patch("crewai.cli.plus_api.PlusAPI.get_agent")
-def test_agent_from_repository(mock_get_agent, mock_get_auth_token):
-    from crewai_tools import SerperDevTool
-
-    mock_get_response = MagicMock()
-    mock_get_response.status_code = 200
-    mock_get_response.json.return_value = {
-        "role": "test role",
-        "goal": "test goal",
-        "backstory": "test backstory",
-        "tools": ["SerperDevTool"],
-    }
-    mock_get_agent.return_value = mock_get_response
-    agent = Agent(from_repository="test_agent")
-
-    assert agent.role == "test role"
-    assert agent.goal == "test goal"
-    assert agent.backstory == "test backstory"
-    assert len(agent.tools) == 1
-    assert isinstance(agent.tools[0], SerperDevTool)
-
-
-@patch("crewai.cli.plus_api.PlusAPI.get_agent")
-def test_agent_from_repository_override_attributes(mock_get_agent, mock_get_auth_token):
-    from crewai_tools import SerperDevTool
-
-    mock_get_response = MagicMock()
-    mock_get_response.status_code = 200
-    mock_get_response.json.return_value = {
-        "role": "test role",
-        "goal": "test goal",
-        "backstory": "test backstory",
-        "tools": ["SerperDevTool"],
-    }
-    mock_get_agent.return_value = mock_get_response
-    agent = Agent(from_repository="test_agent", role="Custom Role")
-
-    assert agent.role == "Custom Role"
-    assert agent.goal == "test goal"
-    assert agent.backstory == "test backstory"
-    assert len(agent.tools) == 1
-    assert isinstance(agent.tools[0], SerperDevTool)
-
-
-@patch("crewai.cli.plus_api.PlusAPI.get_agent")
-def test_agent_from_repository_with_invalid_tools(mock_get_agent, mock_get_auth_token):
-    mock_get_response = MagicMock()
-    mock_get_response.status_code = 200
-    mock_get_response.json.return_value = {
-        "role": "test role",
-        "goal": "test goal",
-        "backstory": "test backstory",
-        "tools": ["DoesNotExist"],
-    }
-    mock_get_agent.return_value = mock_get_response
-    with pytest.raises(
-        AgentRepositoryError,
-        match="Tool DoesNotExist could not be loaded: module 'crewai_tools' has no attribute 'DoesNotExist'",
-    ):
-        Agent(from_repository="test_agent")
-
-
-@patch("crewai.cli.plus_api.PlusAPI.get_agent")
-def test_agent_from_repository_agent_not_found(mock_get_agent, mock_get_auth_token):
-    mock_get_response = MagicMock()
-    mock_get_response.status_code = 404
-    mock_get_response.text = "Agent not found"
-    mock_get_agent.return_value = mock_get_response
-    with pytest.raises(
-        AgentRepositoryError,
-        match="Agent NOT_FOUND could not be loaded: Agent not found",
-    ):
-        Agent(from_repository="NOT_FOUND")
```
```diff
@@ -220,6 +220,37 @@ def test_get_custom_llm_provider_gemini():
     assert llm._get_custom_llm_provider() == "gemini"
 
 
+def test_is_gemini_model():
+    """Test the _is_gemini_model method with various model names."""
+    llm = LLM(model="gpt-4")  # Model doesn't matter for this test
+
+    assert llm._is_gemini_model("gemini-pro") == True
+    assert llm._is_gemini_model("gemini/gemini-1.5-pro") == True
+    assert llm._is_gemini_model("models/gemini-pro") == True
+    assert llm._is_gemini_model("gemma-7b") == True
+
+    # Should not identify as Gemini models
+    assert llm._is_gemini_model("gpt-4") == False
+    assert llm._is_gemini_model("claude-3") == False
+    assert llm._is_gemini_model("mistral-7b") == False
+
+
+def test_normalize_gemini_model():
+    """Test the _normalize_gemini_model method with various model formats."""
+    llm = LLM(model="gpt-4")  # Model doesn't matter for this test
+
+    assert llm._normalize_gemini_model("gemini/gemini-1.5-pro") == "gemini/gemini-1.5-pro"
+
+    assert llm._normalize_gemini_model("models/gemini-pro") == "gemini/gemini-pro"
+    assert llm._normalize_gemini_model("models/gemini-1.5-flash") == "gemini/gemini-1.5-flash"
+
+    assert llm._normalize_gemini_model("gemini-pro") == "gemini/gemini-pro"
+    assert llm._normalize_gemini_model("gemini-1.5-flash") == "gemini/gemini-1.5-flash"
+
+    assert llm._normalize_gemini_model("gpt-4") == "gpt-4"
+    assert llm._normalize_gemini_model("claude-3") == "claude-3"
+
+
 def test_get_custom_llm_provider_openai():
     llm = LLM(model="gpt-4")
     assert llm._get_custom_llm_provider() == None
@@ -274,6 +305,82 @@ def test_gemini_models(model):
     assert "Paris" in result
 
 
+@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
+@pytest.mark.parametrize(
+    "model",
+    [
+        "models/gemini-pro",  # Format from issue #2803
+        "gemini-pro",  # Format without provider prefix
+    ],
+)
+def test_gemini_model_normalization(model):
+    """Test that different Gemini model formats are normalized correctly."""
+    llm = LLM(model=model)
+
+    with patch("litellm.completion") as mock_completion:
+        # Create mocks for response structure
+        mock_message = MagicMock()
+        mock_message.content = "Paris"
+        mock_choice = MagicMock()
+        mock_choice.message = mock_message
+        mock_response = MagicMock()
+        mock_response.choices = [mock_choice]
+
+        # Set up the mocked completion to return the mock response
+        mock_completion.return_value = mock_response
+
+        llm.call("What is the capital of France?")
+
+        # Check that the model was normalized correctly in the call to litellm
+        args, kwargs = mock_completion.call_args
+        assert kwargs["model"].startswith("gemini/")
+        assert "gemini-pro" in kwargs["model"]
+
+
+@pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
+def test_gemini_api_key_mapping():
+    """Test that GOOGLE_API_KEY is mapped to GEMINI_API_KEY for Gemini models."""
+    original_google_api_key = os.environ.get("GOOGLE_API_KEY")
+    original_gemini_api_key = os.environ.get("GEMINI_API_KEY")
+
+    try:
+        # Set up test environment
+        test_api_key = "test_google_api_key"
+        os.environ["GOOGLE_API_KEY"] = test_api_key
+        if "GEMINI_API_KEY" in os.environ:
+            del os.environ["GEMINI_API_KEY"]
+
+        llm = LLM(model="gemini-pro")
+
+        with patch("litellm.completion") as mock_completion:
+            # Create mocks for response structure
+            mock_message = MagicMock()
+            mock_message.content = "Paris"
+            mock_choice = MagicMock()
+            mock_choice.message = mock_message
+            mock_response = MagicMock()
+            mock_response.choices = [mock_choice]
+
+            # Set up the mocked completion to return the mock response
+            mock_completion.return_value = mock_response
+
+            llm.call("What is the capital of France?")
+
+            # Check that GEMINI_API_KEY was set from GOOGLE_API_KEY
+            assert os.environ.get("GEMINI_API_KEY") == test_api_key
+
+    finally:
+        if original_google_api_key is not None:
+            os.environ["GOOGLE_API_KEY"] = original_google_api_key
+        else:
+            os.environ.pop("GOOGLE_API_KEY", None)
+
+        if original_gemini_api_key is not None:
+            os.environ["GEMINI_API_KEY"] = original_gemini_api_key
+        else:
+            os.environ.pop("GEMINI_API_KEY", None)
+
+
 @pytest.mark.vcr(filter_headers=["authorization"], filter_query_parameters=["key"])
 @pytest.mark.parametrize(
     "model",
```
```diff
@@ -1,173 +0,0 @@
-import pytest
-from pydantic import BaseModel, Field
-
-from crewai.tools.structured_tool import CrewStructuredTool
-
-
-class StringInputSchema(BaseModel):
-    """Schema with a string input field."""
-    query: str = Field(description="A string input parameter")
-
-
-class IntInputSchema(BaseModel):
-    """Schema with an integer input field."""
-    number: int = Field(description="An integer input parameter")
-
-
-class ComplexInputSchema(BaseModel):
-    """Schema with multiple fields of different types."""
-    text: str = Field(description="A string parameter")
-    number: int = Field(description="An integer parameter")
-    flag: bool = Field(description="A boolean parameter")
-
-
-def test_parse_args_with_string_input():
-    """Test that string inputs are parsed correctly."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test with direct string input
-    result = tool._parse_args({"query": "test string"})
-    assert result["query"] == "test string"
-    assert isinstance(result["query"], str)
-
-    # Test with JSON string input
-    result = tool._parse_args('{"query": "json string"}')
-    assert result["query"] == "json string"
-    assert isinstance(result["query"], str)
-
-
-def test_parse_args_with_nested_dict_for_string():
-    """Test that nested dictionaries with 'value' field are handled correctly for string fields."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test with nested dict input (simulating the issue from different LLM providers)
-    nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
-    result = tool._parse_args(nested_input)
-    assert result["query"] == "test value"
-    assert isinstance(result["query"], str)
-
-
-def test_parse_args_with_nested_dict_for_int():
-    """Test that nested dictionaries with 'value' field are handled correctly for int fields."""
-    def test_func(number: int) -> str:
-        return f"Processed: {number}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="IntTool",
-        description="A tool that processes integer input"
-    )
-
-    # Test with nested dict input for int field
-    nested_input = {"number": {"description": "An integer input parameter", "value": 42}}
-    result = tool._parse_args(nested_input)
-    assert result["number"] == 42
-    assert isinstance(result["number"], int)
-
-
-def test_parse_args_with_complex_input():
-    """Test that complex inputs with multiple fields are handled correctly."""
-    def test_func(text: str, number: int, flag: bool) -> str:
-        return f"Processed: {text}, {number}, {flag}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="ComplexTool",
-        description="A tool that processes complex input"
-    )
-
-    # Test with mixed nested dict input
-    complex_input = {
-        "text": {"description": "A string parameter", "value": "test text"},
-        "number": 42,
-        "flag": True
-    }
-    result = tool._parse_args(complex_input)
-    assert result["text"] == "test text"
-    assert isinstance(result["text"], str)
-    assert result["number"] == 42
-    assert isinstance(result["number"], int)
-    assert result["flag"] is True
-    assert isinstance(result["flag"], bool)
-
-
-def test_invoke_with_nested_dict():
-    """Test that invoking a tool with nested dict input works correctly."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test invoking with nested dict input
-    nested_input = {"query": {"description": "A string input parameter", "value": "test value"}}
-    result = tool.invoke(nested_input)
-    assert result == "Processed: test value"
-
-
-def test_nested_dict_without_value_key():
-    """Test that nested dictionaries without 'value' field raise appropriate errors."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test with nested dict without 'value' key
-    invalid_input = {"query": {"description": "A string input parameter", "other_key": "test"}}
-    with pytest.raises(ValueError):
-        tool._parse_args(invalid_input)
-
-
-def test_empty_nested_dict():
-    """Test handling of empty nested dictionaries."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test with empty nested dict
-    empty_dict_input = {"query": {}}
-    with pytest.raises(ValueError):
-        tool._parse_args(empty_dict_input)
-
-
-def test_deeply_nested_structure():
-    """Test handling of deeply nested structures."""
-    def test_func(query: str) -> str:
-        return f"Processed: {query}"
-
-    tool = CrewStructuredTool.from_function(
-        func=test_func,
-        name="StringTool",
-        description="A tool that processes string input"
-    )
-
-    # Test with deeply nested structure
-    deeply_nested = {"query": {"nested": {"deeper": {"value": "deep value"}}}}
-    with pytest.raises(ValueError):
-        tool._parse_args(deeply_nested)
```