mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 08:08:32 +00:00
Revert "working around OAI new update for now"
This reverts commit 23a16eb446.
This commit is contained in:
@@ -1,8 +1,4 @@
|
|||||||
import portalocker
|
import portalocker
|
||||||
import os
|
|
||||||
import tempfile
|
|
||||||
from contextlib import contextmanager
|
|
||||||
from copy import deepcopy
|
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
from abc import ABC, abstractmethod
|
||||||
from typing import Any
|
from typing import Any
|
||||||
@@ -11,111 +7,6 @@ from pydantic import BaseModel, ConfigDict, Field, model_validator
|
|||||||
from crewai.tools import BaseTool
|
from crewai.tools import BaseTool
|
||||||
|
|
||||||
|
|
||||||
def _fix_openai_config(config: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Fix deprecated OpenAI configuration parameters in a config dictionary."""
|
|
||||||
if not config:
|
|
||||||
return config
|
|
||||||
|
|
||||||
# Create a deep copy to avoid modifying the original
|
|
||||||
fixed_config = deepcopy(config)
|
|
||||||
|
|
||||||
def _is_azure_config(cfg: dict[str, Any]) -> bool:
|
|
||||||
"""Determine if this is an Azure OpenAI configuration."""
|
|
||||||
# Check for explicit Azure indicators
|
|
||||||
if cfg.get('provider') == 'azure_openai':
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check for Azure URLs in various fields
|
|
||||||
for field in ['openai_api_base', 'base_url', 'api_base']:
|
|
||||||
url = cfg.get(field, '')
|
|
||||||
if isinstance(url, str) and 'azure' in url.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
# Check if deployment is present (common in Azure configs)
|
|
||||||
# But only if we also have other Azure indicators
|
|
||||||
if 'deployment' in cfg:
|
|
||||||
for field in ['openai_api_base', 'base_url', 'api_base']:
|
|
||||||
url = cfg.get(field, '')
|
|
||||||
if isinstance(url, str) and 'azure' in url.lower():
|
|
||||||
return True
|
|
||||||
|
|
||||||
return False
|
|
||||||
|
|
||||||
def _fix_config_recursively(cfg: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Recursively fix OpenAI config parameters."""
|
|
||||||
if not isinstance(cfg, dict):
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
# Only fix if this is definitely an Azure configuration
|
|
||||||
if _is_azure_config(cfg):
|
|
||||||
# Fix deprecated Azure OpenAI parameters
|
|
||||||
if 'openai_api_base' in cfg and 'azure_endpoint' not in cfg:
|
|
||||||
cfg['azure_endpoint'] = cfg.pop('openai_api_base')
|
|
||||||
|
|
||||||
if 'base_url' in cfg and 'azure_endpoint' not in cfg:
|
|
||||||
# Only convert base_url to azure_endpoint for Azure URLs
|
|
||||||
base_url = cfg.get('base_url', '')
|
|
||||||
if 'openai.azure.com' in base_url:
|
|
||||||
cfg['azure_endpoint'] = cfg.pop('base_url')
|
|
||||||
|
|
||||||
# Handle deployment -> azure_deployment conversion for Azure configs
|
|
||||||
if 'deployment' in cfg and 'azure_deployment' not in cfg:
|
|
||||||
cfg['azure_deployment'] = cfg.pop('deployment')
|
|
||||||
|
|
||||||
# For non-Azure configs, we might still need to handle some deprecated parameters
|
|
||||||
# but we should NOT convert them to Azure format
|
|
||||||
else:
|
|
||||||
# For regular OpenAI configs, just remove the deprecated openai_api_base if present
|
|
||||||
# since it's not valid for regular OpenAI (should use base_url instead)
|
|
||||||
if 'openai_api_base' in cfg and 'base_url' not in cfg:
|
|
||||||
# Only convert to base_url if it's NOT an Azure URL
|
|
||||||
api_base = cfg.get('openai_api_base', '')
|
|
||||||
if isinstance(api_base, str) and 'openai.azure.com' not in api_base:
|
|
||||||
cfg['base_url'] = cfg.pop('openai_api_base')
|
|
||||||
|
|
||||||
# Recursively fix nested dictionaries
|
|
||||||
for key, value in cfg.items():
|
|
||||||
if isinstance(value, dict):
|
|
||||||
cfg[key] = _fix_config_recursively(value)
|
|
||||||
|
|
||||||
return cfg
|
|
||||||
|
|
||||||
return _fix_config_recursively(fixed_config)
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
|
||||||
def _temporarily_unset_env_vars():
|
|
||||||
"""Temporarily unset problematic environment variables that cause OpenAI validation issues."""
|
|
||||||
problematic_vars = [
|
|
||||||
'AZURE_API_BASE',
|
|
||||||
'OPENAI_API_BASE',
|
|
||||||
'OPENAI_BASE_URL'
|
|
||||||
]
|
|
||||||
|
|
||||||
# Store original values
|
|
||||||
original_values = {}
|
|
||||||
for var in problematic_vars:
|
|
||||||
if var in os.environ:
|
|
||||||
original_values[var] = os.environ[var]
|
|
||||||
del os.environ[var]
|
|
||||||
|
|
||||||
try:
|
|
||||||
# Set the correct Azure environment variables if we had AZURE_API_BASE
|
|
||||||
if 'AZURE_API_BASE' in original_values:
|
|
||||||
os.environ['AZURE_OPENAI_ENDPOINT'] = original_values['AZURE_API_BASE']
|
|
||||||
|
|
||||||
yield
|
|
||||||
finally:
|
|
||||||
# Restore original values
|
|
||||||
for var, value in original_values.items():
|
|
||||||
os.environ[var] = value
|
|
||||||
|
|
||||||
# Clean up the temporary Azure endpoint we set
|
|
||||||
if 'AZURE_API_BASE' in original_values and 'AZURE_OPENAI_ENDPOINT' in os.environ:
|
|
||||||
if os.environ['AZURE_OPENAI_ENDPOINT'] == original_values['AZURE_API_BASE']:
|
|
||||||
del os.environ['AZURE_OPENAI_ENDPOINT']
|
|
||||||
|
|
||||||
|
|
||||||
class Adapter(BaseModel, ABC):
|
class Adapter(BaseModel, ABC):
|
||||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
@@ -153,11 +44,7 @@ class RagTool(BaseTool):
|
|||||||
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
from crewai_tools.adapters.embedchain_adapter import EmbedchainAdapter
|
||||||
|
|
||||||
with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
|
with portalocker.Lock("crewai-rag-tool.lock", timeout=10):
|
||||||
# Fix both environment variables and config parameters
|
app = App.from_config(config=self.config) if self.config else App()
|
||||||
with _temporarily_unset_env_vars():
|
|
||||||
# Fix deprecated OpenAI parameters in config
|
|
||||||
fixed_config = _fix_openai_config(self.config)
|
|
||||||
app = App.from_config(config=fixed_config) if fixed_config else App()
|
|
||||||
|
|
||||||
self.adapter = EmbedchainAdapter(
|
self.adapter = EmbedchainAdapter(
|
||||||
embedchain_app=app, summarize=self.summarize
|
embedchain_app=app, summarize=self.summarize
|
||||||
|
|||||||
Reference in New Issue
Block a user