mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-08 07:38:29 +00:00
Compare commits
1 Commits
devin/1745
...
devin/1745
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
819bd0b3b2 |
@@ -196,9 +196,11 @@ class Agent(BaseAgent):
|
||||
else:
|
||||
# For any other type, attempt to extract relevant attributes
|
||||
llm_params = {
|
||||
"model": getattr(self.llm, "model_name", None)
|
||||
or getattr(self.llm, "deployment_name", None)
|
||||
or str(self.llm),
|
||||
"model": self._normalize_model_name(
|
||||
getattr(self.llm, "model_name", None)
|
||||
or getattr(self.llm, "deployment_name", None)
|
||||
or str(self.llm)
|
||||
),
|
||||
"temperature": getattr(self.llm, "temperature", None),
|
||||
"max_tokens": getattr(self.llm, "max_tokens", None),
|
||||
"logprobs": getattr(self.llm, "logprobs", None),
|
||||
@@ -534,5 +536,14 @@ class Agent(BaseAgent):
|
||||
def __tools_names(tools) -> str:
|
||||
return ", ".join([t.name for t in tools])
|
||||
|
||||
def _normalize_model_name(self, model_name):
|
||||
"""
|
||||
Normalize the model name by removing any 'models/' prefix.
|
||||
This fixes the issue with ChatGoogleGenerativeAI and potentially other LLM providers.
|
||||
"""
|
||||
if model_name and isinstance(model_name, str) and model_name.startswith("models/"):
|
||||
return model_name[7:] # Remove "models/" prefix
|
||||
return model_name
|
||||
|
||||
def __repr__(self):
|
||||
return f"Agent(role={self.role}, goal={self.goal}, backstory={self.backstory})"
|
||||
|
||||
@@ -12,7 +12,7 @@ from crewai.utilities.task_output_storage_handler import TaskOutputStorageHandle
|
||||
def reset_memories_command(
|
||||
long,
|
||||
short,
|
||||
entities, # Changed from entity to entities to match CLI parameter
|
||||
entity,
|
||||
knowledge,
|
||||
kickoff_outputs,
|
||||
all,
|
||||
@@ -23,7 +23,7 @@ def reset_memories_command(
|
||||
Args:
|
||||
long (bool): Whether to reset the long-term memory.
|
||||
short (bool): Whether to reset the short-term memory.
|
||||
entities (bool): Whether to reset the entity memory.
|
||||
entity (bool): Whether to reset the entity memory.
|
||||
kickoff_outputs (bool): Whether to reset the latest kickoff task outputs.
|
||||
all (bool): Whether to reset all memories.
|
||||
knowledge (bool): Whether to reset the knowledge.
|
||||
@@ -45,7 +45,7 @@ def reset_memories_command(
|
||||
if short:
|
||||
ShortTermMemory().reset()
|
||||
click.echo("Short term memory has been reset.")
|
||||
if entities: # Changed from entity to entities
|
||||
if entity:
|
||||
EntityMemory().reset()
|
||||
click.echo("Entity memory has been reset.")
|
||||
if kickoff_outputs:
|
||||
|
||||
@@ -1,62 +0,0 @@
|
||||
import os
|
||||
import tempfile
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
import pytest
|
||||
from click.testing import CliRunner
|
||||
|
||||
from crewai.cli.cli import reset_memories
|
||||
from crewai.cli.reset_memories_command import reset_memories_command
|
||||
|
||||
|
||||
def test_reset_memories_command_parameters():
|
||||
"""Test that the CLI parameters match the function parameters."""
|
||||
# Create a mock for reset_memories_command
|
||||
with patch('crewai.cli.cli.reset_memories_command') as mock_reset:
|
||||
runner = CliRunner()
|
||||
|
||||
# Test with entities flag
|
||||
result = runner.invoke(reset_memories, ['--entities'])
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Check that the function was called with the correct parameters
|
||||
# The third parameter should be True for entities
|
||||
mock_reset.assert_called_once_with(False, False, True, False, False, False)
|
||||
|
||||
|
||||
def test_reset_memories_all_flag():
|
||||
"""Test that the --all flag resets all memories."""
|
||||
with patch('crewai.cli.cli.reset_memories_command') as mock_reset:
|
||||
runner = CliRunner()
|
||||
|
||||
# Test with all flag
|
||||
result = runner.invoke(reset_memories, ['--all'])
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Check that the function was called with the correct parameters
|
||||
# The last parameter should be True for all
|
||||
mock_reset.assert_called_once_with(False, False, False, False, False, True)
|
||||
|
||||
|
||||
def test_reset_memories_knowledge_flag():
|
||||
"""Test that the --knowledge flag resets knowledge storage."""
|
||||
with patch('crewai.cli.cli.reset_memories_command') as mock_reset:
|
||||
runner = CliRunner()
|
||||
|
||||
# Test with knowledge flag
|
||||
result = runner.invoke(reset_memories, ['--knowledge'])
|
||||
assert result.exit_code == 0
|
||||
|
||||
# Check that the function was called with the correct parameters
|
||||
# The fourth parameter should be True for knowledge
|
||||
mock_reset.assert_called_once_with(False, False, False, True, False, False)
|
||||
|
||||
|
||||
def test_reset_memories_no_flags():
|
||||
"""Test that an error message is shown when no flags are provided."""
|
||||
runner = CliRunner()
|
||||
|
||||
# Test with no flags
|
||||
result = runner.invoke(reset_memories, [])
|
||||
assert result.exit_code == 0
|
||||
assert "Please specify at least one memory type" in result.output
|
||||
39
tests/test_agent_model_name.py
Normal file
39
tests/test_agent_model_name.py
Normal file
@@ -0,0 +1,39 @@
|
||||
import pytest
|
||||
from unittest.mock import MagicMock, patch
|
||||
from crewai import Agent
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
def test_normalize_model_name_method():
|
||||
"""Test that the _normalize_model_name method correctly handles model names with 'models/' prefix"""
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4"
|
||||
)
|
||||
|
||||
model_with_prefix = "models/gemini/gemini-1.5-flash"
|
||||
normalized_name = agent._normalize_model_name(model_with_prefix)
|
||||
assert normalized_name == "gemini/gemini-1.5-flash"
|
||||
|
||||
regular_model = "gpt-4"
|
||||
assert agent._normalize_model_name(regular_model) == "gpt-4"
|
||||
|
||||
assert agent._normalize_model_name(None) is None
|
||||
|
||||
assert agent._normalize_model_name(123) == 123
|
||||
|
||||
|
||||
def test_agent_with_regular_model_name():
|
||||
"""Test that the Agent class doesn't modify normal model names"""
|
||||
with patch('crewai.agent.LLM') as mock_llm:
|
||||
agent = Agent(
|
||||
role="Test Agent",
|
||||
goal="Test goal",
|
||||
backstory="Test backstory",
|
||||
llm="gpt-4"
|
||||
)
|
||||
|
||||
args, kwargs = mock_llm.call_args
|
||||
assert kwargs["model"] == "gpt-4"
|
||||
Reference in New Issue
Block a user