diff --git a/pyproject.toml b/pyproject.toml
index a45a52da4..f15e21a0f 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -69,6 +69,9 @@ docling = [
 aisuite = [
     "aisuite>=0.1.10",
 ]
+mlflow = [
+    "mlflow>=2.19.0",
+]
 
 [tool.uv]
 dev-dependencies = [
diff --git a/src/crewai/integrations/__init__.py b/src/crewai/integrations/__init__.py
new file mode 100644
index 000000000..6b37f9611
--- /dev/null
+++ b/src/crewai/integrations/__init__.py
@@ -0,0 +1 @@
+"""CrewAI integrations with external tools and services."""
diff --git a/src/crewai/integrations/mlflow.py b/src/crewai/integrations/mlflow.py
new file mode 100644
index 000000000..352805749
--- /dev/null
+++ b/src/crewai/integrations/mlflow.py
@@ -0,0 +1,92 @@
+"""MLFlow integration utilities for CrewAI."""
+
+import logging
+from typing import Optional
+
+logger = logging.getLogger(__name__)
+
+
+def is_mlflow_available() -> bool:
+    """Check if MLFlow is available."""
+    try:
+        import mlflow
+        return True
+    except ImportError:
+        return False
+
+
+def setup_mlflow_autolog(
+    log_traces: bool = True,
+    log_models: bool = False,
+    disable: bool = False,
+    exclusive: bool = False,
+    disable_for_unsupported_versions: bool = False,
+    silent: bool = False,
+) -> bool:
+    """
+    Setup MLFlow autologging for CrewAI.
+
+    This is a convenience wrapper around mlflow.crewai.autolog() that provides
+    better error handling and documentation.
+
+    Args:
+        log_traces: Whether to log traces
+        log_models: Whether to log models
+        disable: Whether to disable autologging
+        exclusive: Whether to use exclusive mode
+        disable_for_unsupported_versions: Whether to disable for unsupported versions
+        silent: Whether to suppress warnings
+
+    Returns:
+        True if autologging was successfully enabled, False otherwise
+    """
+    if not is_mlflow_available():
+        if not silent:
+            logger.warning(
+                "MLFlow is not available. Install it with: pip install mlflow"
+            )
+        return False
+
+    try:
+        import mlflow
+        mlflow.crewai.autolog(
+            log_traces=log_traces,
+            log_models=log_models,
+            disable=disable,
+            exclusive=exclusive,
+            disable_for_unsupported_versions=disable_for_unsupported_versions,
+            silent=silent,
+        )
+        if not silent:
+            logger.info("MLFlow autologging enabled for CrewAI")
+        return True
+    except Exception as e:
+        if not silent:
+            logger.error(f"Failed to enable MLFlow autologging: {e}")
+        return False
+
+
+def get_active_run():
+    """Get the active MLFlow run if available."""
+    if not is_mlflow_available():
+        return None
+
+    try:
+        import mlflow
+        return mlflow.active_run()
+    except Exception:
+        return None
+
+
+def log_crew_execution(crew_name: str, **kwargs):
+    """Log crew execution details to MLFlow if available."""
+    if not is_mlflow_available():
+        return
+
+    try:
+        import mlflow
+        with mlflow.start_run(run_name=f"crew_{crew_name}"):
+            for key, value in kwargs.items():
+                mlflow.log_param(key, value)
+    except Exception as e:
+        logger.debug(f"Failed to log crew execution to MLFlow: {e}")
diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py
new file mode 100644
index 000000000..9c3c84bbe
--- /dev/null
+++ b/tests/integrations/__init__.py
@@ -0,0 +1 @@
+"""Tests for CrewAI integrations."""
diff --git a/tests/integrations/test_mlflow.py b/tests/integrations/test_mlflow.py
new file mode 100644
index 000000000..d8212a989
--- /dev/null
+++ b/tests/integrations/test_mlflow.py
@@ -0,0 +1,207 @@
+"""Tests for MLFlow integration."""
+
+import pytest
+from unittest.mock import Mock, patch, MagicMock
+import sys
+
+
+class TestMLFlowIntegration:
+    """Test MLFlow integration functionality."""
+
+    def test_is_mlflow_available_when_installed(self):
+        """Test that is_mlflow_available returns True when MLFlow is installed."""
+        with patch.dict(sys.modules, {'mlflow': Mock()}):
+            from crewai.integrations.mlflow import is_mlflow_available
+            assert is_mlflow_available() is True
+
+    def test_is_mlflow_available_when_not_installed(self):
+        """Test that is_mlflow_available returns False when MLFlow is not installed."""
+        with patch.dict(sys.modules, {'mlflow': None}):
+            with patch('crewai.integrations.mlflow.logger'):
+                from crewai.integrations.mlflow import is_mlflow_available
+                assert is_mlflow_available() is False
+
+    def test_setup_mlflow_autolog_success(self):
+        """Test successful MLFlow autolog setup."""
+        mock_mlflow = Mock()
+        mock_mlflow.crewai.autolog = Mock()
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            from crewai.integrations.mlflow import setup_mlflow_autolog
+            result = setup_mlflow_autolog()
+
+            assert result is True
+            mock_mlflow.crewai.autolog.assert_called_once_with(
+                log_traces=True,
+                log_models=False,
+                disable=False,
+                exclusive=False,
+                disable_for_unsupported_versions=False,
+                silent=False,
+            )
+
+    def test_setup_mlflow_autolog_not_available(self):
+        """Test MLFlow autolog setup when MLFlow is not available."""
+        with patch.dict(sys.modules, {'mlflow': None}):
+            with patch('crewai.integrations.mlflow.logger') as mock_logger:
+                from crewai.integrations.mlflow import setup_mlflow_autolog
+                result = setup_mlflow_autolog()
+
+                assert result is False
+                mock_logger.warning.assert_called_once()
+
+    def test_setup_mlflow_autolog_silent_mode(self):
+        """Test MLFlow autolog setup in silent mode."""
+        with patch.dict(sys.modules, {'mlflow': None}):
+            with patch('crewai.integrations.mlflow.logger') as mock_logger:
+                from crewai.integrations.mlflow import setup_mlflow_autolog
+                result = setup_mlflow_autolog(silent=True)
+
+                assert result is False
+                mock_logger.warning.assert_not_called()
+
+    def test_setup_mlflow_autolog_exception(self):
+        """Test MLFlow autolog setup when an exception occurs."""
+        mock_mlflow = Mock()
+        mock_mlflow.crewai.autolog.side_effect = Exception("Test error")
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            with patch('crewai.integrations.mlflow.logger') as mock_logger:
+                from crewai.integrations.mlflow import setup_mlflow_autolog
+                result = setup_mlflow_autolog()
+
+                assert result is False
+                mock_logger.error.assert_called_once()
+
+    def test_get_active_run_success(self):
+        """Test getting active MLFlow run."""
+        mock_run = Mock()
+        mock_mlflow = Mock()
+        mock_mlflow.active_run.return_value = mock_run
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            from crewai.integrations.mlflow import get_active_run
+            result = get_active_run()
+
+            assert result == mock_run
+            mock_mlflow.active_run.assert_called_once()
+
+    def test_get_active_run_not_available(self):
+        """Test getting active MLFlow run when MLFlow is not available."""
+        with patch.dict(sys.modules, {'mlflow': None}):
+            from crewai.integrations.mlflow import get_active_run
+            result = get_active_run()
+
+            assert result is None
+
+    def test_get_active_run_exception(self):
+        """Test getting active MLFlow run when an exception occurs."""
+        mock_mlflow = Mock()
+        mock_mlflow.active_run.side_effect = Exception("Test error")
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            from crewai.integrations.mlflow import get_active_run
+            result = get_active_run()
+
+            assert result is None
+
+    def test_log_crew_execution_success(self):
+        """Test logging crew execution to MLFlow."""
+        mock_mlflow = Mock()
+        mock_context_manager = Mock()
+        mock_mlflow.start_run.return_value.__enter__ = Mock(return_value=mock_context_manager)
+        mock_mlflow.start_run.return_value.__exit__ = Mock(return_value=None)
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            from crewai.integrations.mlflow import log_crew_execution
+            log_crew_execution("test_crew", param1="value1", param2="value2")
+
+            mock_mlflow.start_run.assert_called_once_with(run_name="crew_test_crew")
+            assert mock_mlflow.log_param.call_count == 2
+
+    def test_log_crew_execution_not_available(self):
+        """Test logging crew execution when MLFlow is not available."""
+        with patch.dict(sys.modules, {'mlflow': None}):
+            from crewai.integrations.mlflow import log_crew_execution
+            log_crew_execution("test_crew", param1="value1")
+
+    def test_log_crew_execution_exception(self):
+        """Test logging crew execution when an exception occurs."""
+        mock_mlflow = Mock()
+        mock_mlflow.start_run.side_effect = Exception("Test error")
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            with patch('crewai.integrations.mlflow.logger') as mock_logger:
+                from crewai.integrations.mlflow import log_crew_execution
+                log_crew_execution("test_crew", param1="value1")
+                mock_logger.debug.assert_called_once()
+
+
+class TestMLFlowAutologIntegration:
+    """Test the actual MLFlow autolog integration."""
+
+    @pytest.mark.skipif(
+        not pytest.importorskip("mlflow", minversion="2.19.0"),
+        reason="MLFlow not available or version too old"
+    )
+    def test_mlflow_crewai_autolog_exists(self):
+        """Test that mlflow.crewai.autolog exists and can be called."""
+        try:
+            import mlflow
+            assert hasattr(mlflow, 'crewai')
+            assert hasattr(mlflow.crewai, 'autolog')
+
+            mlflow.crewai.autolog(disable=True)  # Disable to avoid side effects
+
+        except ImportError:
+            pytest.skip("MLFlow not available")
+        except Exception as e:
+            pytest.fail(f"mlflow.crewai.autolog() failed: {e}")
+
+    def test_mlflow_integration_with_crew(self):
+        """Test MLFlow integration with CrewAI Crew class."""
+        mock_mlflow = Mock()
+        mock_mlflow.crewai.autolog = Mock()
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            from crewai.integrations.mlflow import setup_mlflow_autolog
+            from crewai import Crew, Agent, Task
+
+            setup_mlflow_autolog()
+
+            agent = Agent(
+                role="Test Agent",
+                goal="Test goal",
+                backstory="Test backstory"
+            )
+
+            task = Task(
+                description="Test task",
+                expected_output="Test output",
+                agent=agent
+            )
+
+            crew = Crew(
+                agents=[agent],
+                tasks=[task]
+            )
+
+            assert crew is not None
+            assert len(crew.agents) == 1
+            assert len(crew.tasks) == 1
+
+    def test_documentation_example(self):
+        """Test the example from the documentation."""
+        mock_mlflow = Mock()
+        mock_mlflow.crewai.autolog = Mock()
+        mock_mlflow.set_tracking_uri = Mock()
+
+        with patch.dict(sys.modules, {'mlflow': mock_mlflow}):
+            import mlflow
+
+            mlflow.set_tracking_uri("http://localhost:5000")
+
+            mlflow.crewai.autolog()
+
+            mock_mlflow.set_tracking_uri.assert_called_once_with("http://localhost:5000")
+            mock_mlflow.crewai.autolog.assert_called_once()
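
Usage sketch (not part of the diff): the snippet below shows how the new helpers might be wired into a crew run. It assumes mlflow>=2.19.0 is installed; the tracking URI, crew name, and agent/task contents are placeholders, not values taken from this change.

import mlflow
from crewai import Agent, Crew, Task
from crewai.integrations.mlflow import log_crew_execution, setup_mlflow_autolog

mlflow.set_tracking_uri("http://localhost:5000")  # hypothetical local tracking server
setup_mlflow_autolog(log_traces=True)             # wraps mlflow.crewai.autolog() and returns False if MLFlow is missing

researcher = Agent(role="Researcher", goal="Summarize a topic", backstory="Placeholder backstory")
task = Task(description="Write a short summary", expected_output="A paragraph", agent=researcher)
crew = Crew(agents=[researcher], tasks=[task])

result = crew.kickoff()

# Optionally record coarse run parameters in a separate MLFlow run, alongside the autologged traces.
log_crew_execution("research_crew", agents=len(crew.agents), tasks=len(crew.tasks))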