diff --git a/src/crewai/agent.py b/src/crewai/agent.py index 999d1d800..1383ae43e 100644 --- a/src/crewai/agent.py +++ b/src/crewai/agent.py @@ -190,9 +190,71 @@ class Agent(BaseAgent): if key not in ["prompt", "key_name", "default"]: # Only add default if the key is already set in os.environ if key in os.environ: - llm_params[key] = value + try: + # Create a new dictionary for properly typed parameters + typed_params = {} + + # Convert and validate values based on parameter type + if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']: + if value is not None: + try: + typed_params[key] = float(value) + except (ValueError, TypeError): + pass + elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']: + if value is not None: + try: + typed_params[key] = int(value) + except (ValueError, TypeError): + pass + elif key == 'logit_bias' and isinstance(value, str): + try: + bias_dict = {} + for pair in value.split(','): + token_id, bias = pair.split(':') + bias_dict[int(token_id.strip())] = float(bias.strip()) + typed_params[key] = bias_dict + except (ValueError, AttributeError): + pass + elif key == 'response_format' and isinstance(value, str): + try: + import json + typed_params[key] = json.loads(value) + except json.JSONDecodeError: + pass + elif key == 'logprobs': + if value is not None: + typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value) + elif key == 'callbacks': + typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value + elif key == 'stop': + typed_params[key] = [value] if isinstance(value, str) else value + elif key in ['model', 'base_url', 'api_version', 'api_key']: + typed_params[key] = value + + # Update llm_params with properly typed values + if typed_params: + llm_params.update(typed_params) + except (ValueError, AttributeError, json.JSONDecodeError): + continue - self.llm = LLM(**llm_params) + # Create LLM instance with properly typed parameters + valid_params = { + 'model', 'timeout', 'temperature', 'top_p', 'n', 'stop', + 'max_completion_tokens', 'max_tokens', 'presence_penalty', + 'frequency_penalty', 'logit_bias', 'response_format', + 'seed', 'logprobs', 'top_logprobs', 'base_url', + 'api_version', 'api_key', 'callbacks' + } + + # Filter out None values and invalid parameters + filtered_params = {} + for k, v in llm_params.items(): + if k in valid_params and v is not None: + filtered_params[k] = v + + # Create LLM instance with properly typed parameters + self.llm = LLM(**filtered_params) else: # For any other type, attempt to extract relevant attributes llm_params = { diff --git a/tests/crew_test.py b/tests/crew_test.py index 0cb8f469c..59fc1615f 100644 --- a/tests/crew_test.py +++ b/tests/crew_test.py @@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools(): crew.kickoff() # Verify task tools override agent tools - assert len(task.tools) == 1 # AnotherTestTool - assert any(isinstance(tool, AnotherTestTool) for tool in task.tools) - assert not any(isinstance(tool, TestTool) for tool in task.tools) + tools = task.tools or [] + assert len(tools) == 1 # AnotherTestTool + assert any(isinstance(tool, AnotherTestTool) for tool in tools) + assert not any(isinstance(tool, TestTool) for tool in tools) # Verify agent tools remain unchanged assert len(new_researcher.tools) == 1