Mirror of https://github.com/crewAIInc/crewAI.git, last synced 2025-12-16 12:28:30 +00:00.
Compare commits
4 Commits
devin/1740
...
devin/1735
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
56fb691b8f | ||
|
|
4c3253e800 | ||
|
|
8517a1462a | ||
|
|
452aa9f173 |
@@ -137,6 +137,18 @@ class Agent(BaseAgent):
|
||||
|
||||
@model_validator(mode="after")
|
||||
def post_init_setup(self):
|
||||
# Handle case-insensitive LLM parameter
|
||||
if hasattr(self, 'LLM'):
|
||||
import warnings
|
||||
warnings.warn(
|
||||
"Using 'LLM' parameter is deprecated. Use lowercase 'llm' instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2
|
||||
)
|
||||
# Transfer LLM value to llm
|
||||
self.llm = getattr(self, 'LLM')
|
||||
delattr(self, 'LLM')
|
||||
|
||||
self._set_knowledge()
|
||||
self.agent_ops_agent_name = self.role
|
||||
unaccepted_attributes = [
|
||||
@@ -144,6 +156,9 @@ class Agent(BaseAgent):
|
||||
"AWS_SECRET_ACCESS_KEY",
|
||||
"AWS_REGION_NAME",
|
||||
]
|
||||
|
||||
# Initialize LLM parameters
|
||||
llm_params: Dict[str, Any] = {}
|
||||
|
||||
# Handle different cases for self.llm
|
||||
if isinstance(self.llm, str):
|
||||
@@ -190,7 +205,71 @@ class Agent(BaseAgent):
|
||||
if key not in ["prompt", "key_name", "default"]:
|
||||
# Only add default if the key is already set in os.environ
|
||||
if key in os.environ:
|
||||
llm_params[key] = value
|
||||
# Convert environment variables to proper types
|
||||
try:
|
||||
param_value = None
|
||||
|
||||
# Integer parameters
|
||||
if key in ['timeout', 'max_tokens', 'n', 'max_completion_tokens']:
|
||||
try:
|
||||
param_value = int(str(value)) if value else None
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Float parameters
|
||||
elif key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
|
||||
try:
|
||||
param_value = float(str(value)) if value else None
|
||||
except (ValueError, TypeError):
|
||||
continue
|
||||
|
||||
# Boolean parameters
|
||||
elif key == 'logprobs':
|
||||
if isinstance(value, bool):
|
||||
param_value = value
|
||||
elif isinstance(value, str):
|
||||
param_value = value.lower() == 'true'
|
||||
|
||||
# Dict parameters
|
||||
elif key == 'logit_bias' and value:
|
||||
try:
|
||||
if isinstance(value, dict):
|
||||
param_value = {int(k): float(v) for k, v in value.items()}
|
||||
elif isinstance(value, str):
|
||||
import json
|
||||
bias_dict = json.loads(value)
|
||||
param_value = {int(k): float(v) for k, v in bias_dict.items()}
|
||||
except (ValueError, TypeError, json.JSONDecodeError):
|
||||
continue
|
||||
|
||||
elif key == 'response_format' and value:
|
||||
try:
|
||||
if isinstance(value, dict):
|
||||
param_value = value
|
||||
elif isinstance(value, str):
|
||||
import json
|
||||
param_value = json.loads(value)
|
||||
except (ValueError, json.JSONDecodeError):
|
||||
continue
|
||||
|
||||
# List parameters
|
||||
elif key == 'callbacks':
|
||||
if isinstance(value, (list, tuple)):
|
||||
param_value = list(value)
|
||||
elif isinstance(value, str):
|
||||
param_value = [cb.strip() for cb in value.split(',') if cb.strip()]
|
||||
else:
|
||||
param_value = []
|
||||
|
||||
# String and other parameters
|
||||
else:
|
||||
param_value = value
|
||||
|
||||
if param_value is not None:
|
||||
llm_params[key] = param_value
|
||||
except Exception:
|
||||
# Skip any invalid values
|
||||
continue
|
||||
|
||||
self.llm = LLM(**llm_params)
|
||||
else:
|
||||
|
||||
65
tests/test_agent.py
Normal file
65
tests/test_agent.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import os
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
|
||||
from crewai.agent import Agent
|
||||
from crewai.llm import LLM
|
||||
|
||||
|
||||
def test_agent_with_custom_llm():
    """Verify that an agent configured with a pre-built LLM instance keeps it."""
    agent = Agent()
    agent.role = "test"
    agent.goal = "test"
    agent.backstory = "test"
    agent.allow_delegation = False
    # Attach a ready-made LLM object rather than a model-name string.
    agent.llm = LLM(model="gpt-4")
    agent.post_init_setup()

    # The instance should pass through setup unchanged.
    assert isinstance(agent.llm, LLM)
    assert agent.llm.model == "gpt-4"
|
||||
|
||||
def test_agent_with_uppercase_llm_param():
    """Verify the deprecated uppercase 'LLM' attribute is migrated to 'llm'."""
    custom_llm = LLM(model="gpt-4")
    with pytest.warns(DeprecationWarning):
        agent = Agent()
        agent.role = "test"
        agent.goal = "test"
        agent.backstory = "test"
        agent.allow_delegation = False
        # Deprecated spelling: setup should warn, copy it to `llm`, then drop it.
        agent.LLM = custom_llm
        agent.post_init_setup()

    assert isinstance(agent.llm, LLM)
    assert agent.llm.model == "gpt-4"
    assert not hasattr(agent, 'LLM')
|
||||
|
||||
def test_agent_llm_parameter_types():
    """Verify env-provided LLM parameters are coerced to their proper Python types."""
    env_vars = {
        "temperature": "0.7",
        "max_tokens": "100",
        "presence_penalty": "0.5",
        "logprobs": "true",
        "logit_bias": '{"50256": -100}',
        "callbacks": "callback1,callback2",
    }
    with mock.patch.dict(os.environ, env_vars):
        agent = Agent()
        agent.role = "test"
        agent.goal = "test"
        agent.backstory = "test"
        agent.allow_delegation = False
        # A plain model-name string forces setup to build the LLM from params.
        agent.llm = "gpt-4"
        agent.post_init_setup()

        # Every value should have been converted from its env-string form:
        # float, int, bool, JSON dict with numeric keys/values, and CSV list.
        assert isinstance(agent.llm, LLM)
        assert agent.llm.temperature == 0.7
        assert agent.llm.max_tokens == 100
        assert agent.llm.presence_penalty == 0.5
        assert agent.llm.logprobs is True
        assert agent.llm.logit_bias == {50256: -100.0}
        assert agent.llm.callbacks == ["callback1", "callback2"]
|
||||
Reference in New Issue
Block a user