Mirror of https://github.com/crewAIInc/crewAI.git (synced 2026-01-09 08:08:32 +00:00)
fix: Improve type conversion for LLM parameters and handle None values properly
Co-Authored-By: Joe Moura <joao@crewai.com>
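For context, a minimal sketch (not part of the commit) of the failure mode the fix targets: parameter values sourced from environment variables always arrive as strings, so numeric settings such as temperature have to be coerced before they reach the LLM constructor, and unparseable values are better dropped than forwarded as-is. The variable name below is illustrative.

import os

# Hypothetical env-driven setting; os.environ values are always strings.
os.environ.setdefault("TEMPERATURE", "0.7")

raw = os.environ["TEMPERATURE"]   # "0.7" (str), not 0.7 (float)

# Without coercion the provider would receive temperature="0.7".
try:
    temperature = float(raw)
except (ValueError, TypeError):
    temperature = None            # unusable value: fall back to the default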
@@ -190,9 +190,71 @@ class Agent(BaseAgent):
                             if key not in ["prompt", "key_name", "default"]:
                                 # Only add default if the key is already set in os.environ
                                 if key in os.environ:
                                     llm_params[key] = value
+                                try:
+                                    # Create a new dictionary for properly typed parameters
+                                    typed_params = {}
+
+                                    # Convert and validate values based on parameter type
+                                    if key in ['temperature', 'top_p', 'presence_penalty', 'frequency_penalty']:
+                                        if value is not None:
+                                            try:
+                                                typed_params[key] = float(value)
+                                            except (ValueError, TypeError):
+                                                pass
+                                    elif key in ['n', 'max_tokens', 'max_completion_tokens', 'seed']:
+                                        if value is not None:
+                                            try:
+                                                typed_params[key] = int(value)
+                                            except (ValueError, TypeError):
+                                                pass
+                                    elif key == 'logit_bias' and isinstance(value, str):
+                                        try:
+                                            bias_dict = {}
+                                            for pair in value.split(','):
+                                                token_id, bias = pair.split(':')
+                                                bias_dict[int(token_id.strip())] = float(bias.strip())
+                                            typed_params[key] = bias_dict
+                                        except (ValueError, AttributeError):
+                                            pass
+                                    elif key == 'response_format' and isinstance(value, str):
+                                        try:
+                                            import json
+                                            typed_params[key] = json.loads(value)
+                                        except json.JSONDecodeError:
+                                            pass
+                                    elif key == 'logprobs':
+                                        if value is not None:
+                                            typed_params[key] = bool(value.lower() == 'true') if isinstance(value, str) else bool(value)
+                                    elif key == 'callbacks':
+                                        typed_params[key] = [] if value is None else [value] if isinstance(value, str) else value
+                                    elif key == 'stop':
+                                        typed_params[key] = [value] if isinstance(value, str) else value
+                                    elif key in ['model', 'base_url', 'api_version', 'api_key']:
+                                        typed_params[key] = value
+
+                                    # Update llm_params with properly typed values
+                                    if typed_params:
+                                        llm_params.update(typed_params)
+                                except (ValueError, AttributeError, json.JSONDecodeError):
+                                    continue
 
-            self.llm = LLM(**llm_params)
+            # Create LLM instance with properly typed parameters
+            valid_params = {
+                'model', 'timeout', 'temperature', 'top_p', 'n', 'stop',
+                'max_completion_tokens', 'max_tokens', 'presence_penalty',
+                'frequency_penalty', 'logit_bias', 'response_format',
+                'seed', 'logprobs', 'top_logprobs', 'base_url',
+                'api_version', 'api_key', 'callbacks'
+            }
+
+            # Filter out None values and invalid parameters
+            filtered_params = {}
+            for k, v in llm_params.items():
+                if k in valid_params and v is not None:
+                    filtered_params[k] = v
+
+            # Create LLM instance with properly typed parameters
+            self.llm = LLM(**filtered_params)
         else:
             # For any other type, attempt to extract relevant attributes
             llm_params = {
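To make the intent of the added block easier to follow in isolation, here is a standalone sketch of the same per-key coercion and None-filtering. The helper names (coerce_llm_param, filter_llm_params), the key sets, and the example values are illustrative assumptions, not code from this commit or from crewAI.

import json
from typing import Any

# Illustrative grouping of the conversion rules the hunk applies per key.
FLOAT_KEYS = {"temperature", "top_p", "presence_penalty", "frequency_penalty"}
INT_KEYS = {"n", "max_tokens", "max_completion_tokens", "seed"}
PASSTHROUGH_KEYS = {"model", "base_url", "api_version", "api_key"}


def coerce_llm_param(key: str, value: Any) -> Any:
    """Best-effort conversion of one (possibly string-valued) parameter."""
    try:
        if key in FLOAT_KEYS and value is not None:
            return float(value)
        if key in INT_KEYS and value is not None:
            return int(value)
        if key == "logit_bias" and isinstance(value, str):
            # "1234:0.5,5678:-1.0" -> {1234: 0.5, 5678: -1.0}
            return {
                int(token.strip()): float(bias.strip())
                for token, bias in (pair.split(":") for pair in value.split(","))
            }
        if key == "response_format" and isinstance(value, str):
            return json.loads(value)
        if key == "logprobs" and isinstance(value, str):
            return value.lower() == "true"
        if key == "stop" and isinstance(value, str):
            return [value]
        if key in PASSTHROUGH_KEYS:
            return value
    except (ValueError, TypeError, json.JSONDecodeError):
        return None  # unparseable values get dropped by the None filter below
    return value


def filter_llm_params(params: dict, valid: set) -> dict:
    """Drop None values and keys the constructor would not accept."""
    return {k: v for k, v in params.items() if k in valid and v is not None}


# Example: raw values as they might arrive from environment variables.
raw = {"model": "gpt-4o", "temperature": "0.2", "max_tokens": "not-a-number", "seed": None}
typed = {k: coerce_llm_param(k, v) for k, v in raw.items()}
print(filter_llm_params(typed, {"model", "temperature", "max_tokens", "seed"}))
# -> {'model': 'gpt-4o', 'temperature': 0.2}

The two-step shape mirrors the diff: convert each value to the type the parameter expects, then build the final constructor kwargs only from recognized, non-None entries.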
@@ -642,9 +642,10 @@ def test_task_tools_override_agent_tools():
     crew.kickoff()
 
     # Verify task tools override agent tools
-    assert len(task.tools) == 1  # AnotherTestTool
-    assert any(isinstance(tool, AnotherTestTool) for tool in task.tools)
-    assert not any(isinstance(tool, TestTool) for tool in task.tools)
+    tools = task.tools or []
+    assert len(tools) == 1  # AnotherTestTool
+    assert any(isinstance(tool, AnotherTestTool) for tool in tools)
+    assert not any(isinstance(tool, TestTool) for tool in tools)
 
     # Verify agent tools remain unchanged
     assert len(new_researcher.tools) == 1
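The test now reads tools through `task.tools or []`, which keeps the assertions valid even when `tools` is None (calling len() on None or iterating over it raises TypeError). A minimal, self-contained illustration of the guard, using a hypothetical stand-in rather than crewAI's Task class:

from dataclasses import dataclass
from typing import List, Optional


@dataclass
class FakeTask:
    # Stand-in for an object whose `tools` attribute may legitimately be None.
    tools: Optional[List[object]] = None


task = FakeTask()

# Guarded access degrades to an empty list instead of raising TypeError.
tools = task.tools or []
assert len(tools) == 0
assert not any(isinstance(tool, str) for tool in tools)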