fix: Improve type conversion for LLM parameters and handle None values properly

Co-Authored-By: Joe Moura <joao@crewai.com>
This commit is contained in:
Devin AI
2024-12-31 20:45:25 +00:00
parent f75b07ce82
commit dec255e87a
2 changed files with 68 additions and 5 deletions

View File

@@ -190,9 +190,71 @@ class Agent(BaseAgent):
if key not in ["prompt", "key_name", "default"]:
# Only add default if the key is already set in os.environ
if key in os.environ:
llm_params[key] = value
try:
# Create a new dictionary for properly typed parameters
typed_params = {}
# Convert and validate values based on parameter type
if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
if value is not None:
try:
typed_params[key] = float(value)
except (ValueError, TypeError):
pass
elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']:
if value is not None:
try:
typed_params[key] = int(value)
except (ValueError, TypeError):
pass
elif key == 'logit_bias' and isinstance(value, str):
try:
bias_dict = {}
for pair in value.split(','):
token_id, bias = pair.split(':')
bias_dict[int(token_id.strip())] = float(bias.strip())
typed_params[key] = bias_dict
except (ValueError, AttributeError):
pass
elif key == 'response_format' and isinstance(value, str):
try:
import json
typed_params[key] = json.loads(value)
except json.JSONDecodeError:
pass
elif key == 'logprobs':
if value is not None:
typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value)
elif key == 'callbacks':
typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value
elif key == 'stop':
typed_params[key] = [value] if isinstance(value, str) else value
elif key in ['model', 'base_url', 'api_version', 'api_key']:
typed_params[key] = value
# Update llm_params with properly typed values
if typed_params:
llm_params.update(typed_params)
except (ValueError, AttributeError, json.JSONDecodeError):
continue
self.llm = LLM(**llm_params)
# Create LLM instance with properly typed parameters
valid_params = {
'model', 'timeout', 'temperature', 'top_p', 'n', 'stop',
'max_completion_tokens', 'max_tokens', 'presence_penalty',
'frequency_penalty', 'logit_bias', 'response_format',
'seed', 'logprobs', 'top_logprobs', 'base_url',
'api_version', 'api_key', 'callbacks'
}
# Filter out None values and invalid parameters
filtered_params = {}
for k, v in llm_params.items():
if k in valid_params and v is not None:
filtered_params[k] = v
# Create LLM instance with properly typed parameters
self.llm = LLM(**filtered_params)
else:
# For any other type, attempt to extract relevant attributes
llm_params = {

View File

@@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools():
crew.kickoff()
# Verify task tools override agent tools
assert len(task.tools) == 1 # AnotherTestTool
assert any(isinstance(tool, AnotherTestTool) for tool in task.tools)
assert not any(isinstance(tool, TestTool) for tool in task.tools)
tools = task.tools or []
assert len(tools) == 1 # AnotherTestTool
assert any(isinstance(tool, AnotherTestTool) for tool in tools)
assert not any(isinstance(tool, TestTool) for tool in tools)
# Verify agent tools remain unchanged
assert len(new_researcher.tools) == 1