diff --git a/src/crewai/utilities/import_utils.py b/src/crewai/utilities/import_utils.py index e6d807c36..e26698068 100644 --- a/src/crewai/utilities/import_utils.py +++ b/src/crewai/utilities/import_utils.py @@ -2,29 +2,94 @@ import importlib from types import ModuleType +from typing import Annotated, Any, TypeAlias + +from pydantic import AfterValidator, TypeAdapter +from typing_extensions import deprecated +@deprecated( + "Not needed when using `crewai.utilities.import_utils.import_and_validate_definition`" +) class OptionalDependencyError(ImportError): """Exception raised when an optional dependency is not installed.""" -def require(name: str, *, purpose: str) -> ModuleType: - """Import a module, raising a helpful error if it's not installed. +@deprecated( + "Use `crewai.utilities.import_utils.import_and_validate_definition` instead." +) +def require(name: str, *, purpose: str, attr: str | None = None) -> ModuleType | Any: + """Import a module, optionally returning a specific attribute. Args: name: The module name to import. purpose: Description of what requires this dependency. + attr: Optional attribute name to get from the module. Returns: - The imported module. + The imported module or the specified attribute. Raises: OptionalDependencyError: If the module is not installed. + AttributeError: If the specified attribute doesn't exist. """ try: - return importlib.import_module(name) + module = importlib.import_module(name) + if attr is not None: + return getattr(module, attr) + return module except ImportError as exc: + package_name = name.split(".")[0] raise OptionalDependencyError( f"{purpose} requires the optional dependency '{name}'.\n" - f"Install it with: uv add {name}" + f"Install it with: uv add {package_name}" ) from exc + except AttributeError as exc: + raise AttributeError(f"Module '{name}' has no attribute '{attr}'") from exc + + +def validate_import_path(v: str) -> Any: + """Import and return the class/function from the import path. + + Args: + v: Import path string in the format 'module.path.ClassName'. + + Returns: + The imported class or function. + + Raises: + ValueError: If the import path is malformed or the module cannot be imported. + """ + module_path, _, attr = v.rpartition(".") + if not module_path or not attr: + raise ValueError(f"import_path '{v}' must be of the form 'module.ClassName'") + + try: + mod = importlib.import_module(module_path) + except ImportError as exc: + parts = module_path.split(".") + if not parts: + raise ValueError(f"Malformed import path: '{v}'") from exc + package = parts[0] + raise ValueError( + f"Package '{package}' could not be imported. Install it with: uv add {package}" + ) from exc + + if not hasattr(mod, attr): + raise ValueError(f"Attribute '{attr}' not found in module '{module_path}'") + return getattr(mod, attr) + + +ImportedDefinition: TypeAlias = Annotated[Any, AfterValidator(validate_import_path)] +adapter = TypeAdapter(ImportedDefinition) + + +def import_and_validate_definition(v: str) -> Any: + """Pydantic-compatible function to import a class/function from a string path. + + Args: + v: Import path string in the format 'module.path.ClassName'. + Returns: + The imported class or function + """ + return adapter.validate_python(v) diff --git a/tests/utilities/test_import_utils.py b/tests/utilities/test_import_utils.py index 535403aa4..29738172c 100644 --- a/tests/utilities/test_import_utils.py +++ b/tests/utilities/test_import_utils.py @@ -1,9 +1,16 @@ """Tests for import utilities.""" -import pytest -from unittest.mock import patch +import sys +from unittest.mock import MagicMock, patch -from crewai.utilities.import_utils import require, OptionalDependencyError +import pytest + +from crewai.utilities.import_utils import ( + OptionalDependencyError, + import_and_validate_definition, + require, + validate_import_path, +) class TestRequire: @@ -40,3 +47,143 @@ class TestRequire: def test_optional_dependency_error_is_import_error(self): """Test that OptionalDependencyError is a subclass of ImportError.""" assert issubclass(OptionalDependencyError, ImportError) + + def test_require_with_attr(self): + """Test requiring a specific attribute from a module.""" + loads = require("json", purpose="testing", attr="loads") + import json + + assert loads == json.loads + + def test_require_with_nonexistent_attr(self): + """Test requiring a nonexistent attribute raises AttributeError.""" + with pytest.raises(AttributeError) as exc_info: + require("json", purpose="testing", attr="nonexistent_attr") + + assert "Module 'json' has no attribute 'nonexistent_attr'" in str( + exc_info.value + ) + + def test_require_extracts_package_name(self): + """Test that require correctly extracts package name from module path.""" + with pytest.raises(OptionalDependencyError) as exc_info: + require("some.nested.module.path", purpose="testing") + + error_msg = str(exc_info.value) + assert "uv add some" in error_msg + + +class TestValidateImportPath: + """Test the validate_import_path function.""" + + def test_validate_import_path_success(self): + """Test successful import of a class.""" + result = validate_import_path("json.JSONDecoder") + import json + + assert result == json.JSONDecoder + + def test_validate_import_path_malformed_no_module(self): + """Test validation with no module path.""" + with pytest.raises(ValueError) as exc_info: + validate_import_path("ClassName") + + assert "import_path 'ClassName' must be of the form 'module.ClassName'" in str( + exc_info.value + ) + + def test_validate_import_path_empty_string(self): + """Test validation with empty string.""" + with pytest.raises(ValueError) as exc_info: + validate_import_path("") + + assert "import_path '' must be of the form 'module.ClassName'" in str( + exc_info.value + ) + + def test_validate_import_path_module_not_found(self): + """Test validation with non-existent module.""" + with pytest.raises(ValueError) as exc_info: + validate_import_path("nonexistent_module.ClassName") + + error_msg = str(exc_info.value) + assert "Package 'nonexistent_module' could not be imported" in error_msg + assert "uv add nonexistent_module" in error_msg + + def test_validate_import_path_attribute_not_found(self): + """Test validation when attribute doesn't exist in module.""" + with pytest.raises(ValueError) as exc_info: + validate_import_path("json.NonExistentClass") + + assert "Attribute 'NonExistentClass' not found in module 'json'" in str( + exc_info.value + ) + + def test_validate_import_path_nested_module(self): + """Test validation with nested module path.""" + result = validate_import_path("unittest.mock.MagicMock") + from unittest.mock import MagicMock + + assert result == MagicMock + + def test_validate_import_path_extracts_package_name(self): + """Test that package name is correctly extracted for error message.""" + with pytest.raises(ValueError) as exc_info: + validate_import_path("some.nested.module.path.ClassName") + + error_msg = str(exc_info.value) + assert "Package 'some' could not be imported" in error_msg + assert "uv add some" in error_msg + + +class TestImportAndValidateDefinition: + """Test the import_and_validate_definition function.""" + + def test_import_and_validate_definition_success(self): + """Test successful import through Pydantic adapter.""" + result = import_and_validate_definition("json.JSONEncoder") + import json + + assert result == json.JSONEncoder + + def test_import_and_validate_definition_with_function(self): + """Test importing a function instead of a class.""" + result = import_and_validate_definition("json.loads") + import json + + assert result == json.loads + + def test_import_and_validate_definition_invalid(self): + """Test that invalid paths raise ValueError.""" + with pytest.raises(ValueError) as exc_info: + import_and_validate_definition("InvalidPath") + + assert "must be of the form 'module.ClassName'" in str(exc_info.value) + + def test_import_and_validate_definition_module_error(self): + """Test error handling for missing modules.""" + with pytest.raises(ValueError) as exc_info: + import_and_validate_definition("missing_package.SomeClass") + + error_msg = str(exc_info.value) + assert "Package 'missing_package' could not be imported" in error_msg + assert "uv add missing_package" in error_msg + + def test_import_and_validate_definition_attribute_error(self): + """Test error handling for missing attributes.""" + with pytest.raises(ValueError) as exc_info: + import_and_validate_definition("json.MissingClass") + + assert "Attribute 'MissingClass' not found in module 'json'" in str( + exc_info.value + ) + + def test_import_and_validate_definition_with_mock(self): + """Test that mocked modules work correctly.""" + mock_module = MagicMock() + mock_class = MagicMock() + mock_module.MockClass = mock_class + + with patch.dict(sys.modules, {"mocked_module": mock_module}): + result = import_and_validate_definition("mocked_module.MockClass") + assert result == mock_class