diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 999d1d800..6617f258b 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -137,6 +137,18 @@ class Agent(BaseAgent): @model_validator(mode="after") def post_init_setup(self): + # Handle case-insensitive LLM parameter + if hasattr(self, 'LLM'): + import warnings + warnings.warn( + "Using 'LLM' parameter is deprecated. Use lowercase 'llm' instead.", + DeprecationWarning, + stacklevel=2 + ) + # Transfer LLM value to llm + self.llm = getattr(self, 'LLM') + delattr(self, 'LLM') + self._set_knowledge() self.agent_ops_agent_name = self.role unaccepted_attributes = [ @@ -144,6 +156,10 @@ class Agent(BaseAgent): "AWS_SECRET_ACCESS_KEY", "AWS_REGION_NAME", ] + + # Initialize LLM parameters with proper typing + from typing import Any, Dict, List, Union, Optional + llm_params: Dict[str, Any] = {} # Handle different cases for self.llm if isinstance(self.llm, str): @@ -190,7 +206,71 @@ class Agent(BaseAgent): if key not in ["prompt", "key_name", "default"]: # Only add default if the key is already set in os.environ if key in os.environ: - llm_params[key] = value + # Convert environment variables to proper types + try: + param_value = None + + # Integer parameters + if key in ['timeout', 'max_tokens', 'n', 'max_completion_tokens']: + try: + param_value = int(str(value)) if value else None + except (ValueError, TypeError): + continue + + # Float parameters + elif key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']: + try: + param_value = float(str(value)) if value else None + except (ValueError, TypeError): + continue + + # Boolean parameters + elif key == 'logprobs': + if isinstance(value, bool): + param_value = value + elif isinstance(value, str): + param_value = value.lower() == 'true' + + # Dict parameters + elif key == 'logit_bias' and value: + try: + if isinstance(value, dict): + param_value = {int(k): float(v) for k, v in value.items()} + elif isinstance(value, str): + import json + bias_dict = json.loads(value) + param_value = {int(k): float(v) for k, v in bias_dict.items()} + except (ValueError, TypeError, json.JSONDecodeError): + continue + + elif key == 'response_format' and value: + try: + if isinstance(value, dict): + param_value = value + elif isinstance(value, str): + import json + param_value = json.loads(value) + except (ValueError, json.JSONDecodeError): + continue + + # List parameters + elif key == 'callbacks': + if isinstance(value, (list, tuple)): + param_value = list(value) + elif isinstance(value, str): + param_value = [cb.strip() for cb in value.split(',') if cb.strip()] + else: + param_value = [] + + # String and other parameters + else: + param_value = value + + if param_value is not None: + llm_params[key] = param_value + except Exception: + # Skip any invalid values + continue self.llm = LLM(**llm_params) else: diff --git a/tests/test_agent.py b/tests/test_agent.py new file mode 100644 index 000000000..eaad771af --- /dev/null +++ b/tests/test_agent.py @@ -0,0 +1,63 @@ +import os +import pytest +from unittest import mock + +from crewai.agent import Agent +from crewai.llm import LLM + +def test_agent_with_custom_llm(): + """Test creating an agent with a custom LLM.""" + custom_llm = LLM(model="gpt-4") + agent = Agent() + agent.role = "test" + agent.goal = "test" + agent.backstory = "test" + agent.llm = custom_llm + agent.allow_delegation = False + agent.post_init_setup() + + assert isinstance(agent.llm, LLM) + assert agent.llm.model == "gpt-4" + +def test_agent_with_uppercase_llm_param(): + """Test creating an agent with uppercase 'LLM' parameter.""" + custom_llm = LLM(model="gpt-4") + with pytest.warns(DeprecationWarning): + agent = Agent() + agent.role = "test" + agent.goal = "test" + agent.backstory = "test" + setattr(agent, 'LLM', custom_llm) # Using uppercase LLM + agent.allow_delegation = False + agent.post_init_setup() + + assert isinstance(agent.llm, LLM) + assert agent.llm.model == "gpt-4" + assert not hasattr(agent, 'LLM') + +def test_agent_llm_parameter_types(): + """Test LLM parameter type handling.""" + env_vars = { + "temperature": "0.7", + "max_tokens": "100", + "presence_penalty": "0.5", + "logprobs": "true", + "logit_bias": '{"50256": -100}', + "callbacks": "callback1,callback2", + } + with mock.patch.dict(os.environ, env_vars): + agent = Agent() + agent.role = "test" + agent.goal = "test" + agent.backstory = "test" + agent.llm = "gpt-4" + agent.allow_delegation = False + agent.post_init_setup() + + assert isinstance(agent.llm, LLM) + assert agent.llm.temperature == 0.7 + assert agent.llm.max_tokens == 100 + assert agent.llm.presence_penalty == 0.5 + assert agent.llm.logprobs is True + assert agent.llm.logit_bias == {50256: -100.0} + assert agent.llm.callbacks == ["callback1", "callback2"]