diff --git a/src/crewai/llm.py b/src/crewai/llm.py
index da7c2991c..26e03ca79 100644
--- a/src/crewai/llm.py
+++ b/src/crewai/llm.py
@@ -372,16 +372,8 @@ class LLM(BaseLLM):
         Returns:
             bool: True if the model is from Ollama, False otherwise.
         """
-        # Check if model starts with ollama/ prefix
-        if model.startswith("ollama/"):
-            return True
-
-        # Check if the provider extracted from the model is ollama
-        if "/" in model:
-            provider = model.split("/")[0]
-            return provider == "ollama"
-
-        return False
+        OLLAMA_IDENTIFIERS = ("ollama/", "ollama:")
+        return model.lower().startswith(OLLAMA_IDENTIFIERS)
 
     def _prepare_completion_params(
         self,
@@ -430,7 +422,9 @@ class LLM(BaseLLM):
             **self.additional_params,
         }
 
-        if not self._is_ollama_model(self.model):
+        if self._is_ollama_model(self.model):
+            params.pop("response_format", None)  # Ollama does not support response_format; drop it if present
+        else:
             params["response_format"] = self.response_format
 
         # Remove None values from params
@@ -1091,7 +1085,7 @@ class LLM(BaseLLM):
         if self._is_ollama_model(self.model):
             return
 
-        provider = self._get_custom_llm_provider()
+        provider: str = self._get_custom_llm_provider()
         if self.response_format is not None and not supports_response_schema(
             model=self.model,
             custom_llm_provider=provider,
diff --git a/test_ollama_fix.py b/test_ollama_fix.py
index e22e34365..96769641b 100644
--- a/test_ollama_fix.py
+++ b/test_ollama_fix.py
@@ -33,6 +33,8 @@ def test_original_issue():
         )
         print("✅ Agent creation with Ollama LLM succeeded")
 
+        assert agent.llm.model == "ollama/gemma3:latest"
+
     except ValueError as e:
         if "does not support response_format" in str(e):
             print(f"❌ Original issue still exists: {e}")
diff --git a/tests/agent_test.py b/tests/agent_test.py
index 70174802e..96421d707 100644
--- a/tests/agent_test.py
+++ b/tests/agent_test.py
@@ -1691,7 +1691,17 @@ def test_agent_execute_task_with_ollama():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_ollama_model_with_response_format():
-    """Test that Ollama models work correctly when response_format is provided."""
+    """
+    Test Ollama model compatibility with the response_format parameter.
+
+    Verifies:
+    - LLM initialization with response_format doesn't raise ValueError
+    - Agent creation with the configured LLM succeeds
+    - Execution completes without raising ValueError for the unsupported response_format
+
+    Note: This test may fail in CI if an Ollama server is not available,
+    but the core behavior (no ValueError on initialization) should still hold.
+    """
     from pydantic import BaseModel
 
     class TestOutput(BaseModel):
@@ -1719,7 +1729,14 @@ def test_ollama_model_with_response_format():
 
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_ollama_model_response_format_filtered_in_params():
-    """Test that response_format is filtered out for Ollama models in _prepare_completion_params."""
+    """
+    Test that response_format is filtered out for Ollama models in _prepare_completion_params.
+
+    Verifies:
+    - Ollama model detection works correctly for various model formats
+    - response_format parameter is excluded from completion params for Ollama models
+    - Model detection returns correct boolean values for different model types
+    """
     from pydantic import BaseModel
 
     class TestOutput(BaseModel):
@@ -1739,7 +1756,14 @@ def test_ollama_model_response_format_filtered_in_params():
 
 
 def test_non_ollama_model_keeps_response_format():
-    """Test that non-Ollama models still include response_format in params."""
+    """
+    Test that non-Ollama models still include response_format in params.
+
+    Verifies:
+    - Non-Ollama models are correctly identified as such
+    - response_format parameter is preserved for non-Ollama models
+    - Backward compatibility is maintained for existing LLM providers
+    """
     from pydantic import BaseModel
 
     class TestOutput(BaseModel):
@@ -1756,6 +1780,35 @@ def test_non_ollama_model_keeps_response_format():
     assert params.get("response_format") == TestOutput
 
 
+def test_ollama_model_detection_edge_cases():
+    """
+    Test edge cases for Ollama model detection.
+
+    Verifies:
+    - Various Ollama model naming patterns are correctly identified
+    - Case-insensitive detection works properly
+    - Non-Ollama models containing 'ollama' in the name are not misidentified
+    - Different provider prefixes are handled correctly
+    """
+    from crewai.llm import LLM
+
+    test_cases = [
+        ("ollama/llama3.2:3b", True, "Standard ollama/ prefix"),
+        ("OLLAMA/MODEL:TAG", True, "Uppercase ollama/ prefix"),
+        ("ollama:custom-model", True, "ollama: prefix"),
+        ("custom/ollama-model", False, "Contains 'ollama' but not prefix"),
+        ("gpt-4", False, "Non-Ollama model"),
+        ("anthropic/claude-3", False, "Different provider"),
+        ("openai/gpt-4", False, "OpenAI model"),
+        ("ollama/gemma3:latest", True, "Ollama with version tag"),
+    ]
+
+    for model_name, expected, description in test_cases:
+        llm = LLM(model=model_name)
+        result = llm._is_ollama_model(model_name)
+        assert result == expected, f"Failed for {description}: {model_name} -> {result} (expected {expected})"
+
+
 @pytest.mark.vcr(filter_headers=["authorization"])
 def test_agent_with_knowledge_sources():
     content = "Brandon's favorite color is red and he likes Mexican food."
diff --git a/tests/test_ollama_integration.py b/tests/test_ollama_integration.py
new file mode 100644
index 000000000..84d12bca7
--- /dev/null
+++ b/tests/test_ollama_integration.py
@@ -0,0 +1,106 @@
+"""
+Integration tests for Ollama model handling.
+This module tests the Ollama-specific functionality, including response_format handling.
+""" + +from pydantic import BaseModel +from crewai.llm import LLM +from crewai import Agent + +class GuideOutline(BaseModel): + title: str + sections: list[str] + +def test_original_issue(): + """Test the original issue scenario from GitHub issue #3082.""" + print("Testing original issue scenario...") + + try: + llm = LLM(model="ollama/gemma3:latest", response_format=GuideOutline) + print("✅ LLM creation with response_format succeeded") + + params = llm._prepare_completion_params("Test message") + if "response_format" not in params or params.get("response_format") is None: + print("✅ response_format correctly filtered out for Ollama model") + else: + print("❌ response_format was not filtered out") + + agent = Agent( + role="Guide Creator", + goal="Create comprehensive guides", + backstory="You are an expert at creating structured guides", + llm=llm + ) + print("✅ Agent creation with Ollama LLM succeeded") + + assert agent.llm.model == "ollama/gemma3:latest" + + except ValueError as e: + if "does not support response_format" in str(e): + print(f"❌ Original issue still exists: {e}") + return False + else: + print(f"❌ Unexpected ValueError: {e}") + return False + except Exception as e: + print(f"❌ Unexpected error: {e}") + return False + + return True + +def test_non_ollama_models(): + """Test that non-Ollama models still work with response_format.""" + print("\nTesting non-Ollama models...") + + try: + llm = LLM(model="gpt-4", response_format=GuideOutline) + params = llm._prepare_completion_params("Test message") + + if params.get("response_format") == GuideOutline: + print("✅ Non-Ollama models still include response_format") + return True + else: + print("❌ Non-Ollama models missing response_format") + return False + + except Exception as e: + print(f"❌ Error with non-Ollama model: {e}") + return False + +def test_ollama_model_detection_edge_cases(): + """Test edge cases for Ollama model detection.""" + print("\nTesting Ollama model detection edge cases...") + + test_cases = [ + ("ollama/llama3.2:3b", True, "Standard ollama/ prefix"), + ("OLLAMA/MODEL:TAG", True, "Uppercase ollama/ prefix"), + ("ollama:custom-model", True, "ollama: prefix"), + ("custom/ollama-model", False, "Contains 'ollama' but not prefix"), + ("gpt-4", False, "Non-Ollama model"), + ("anthropic/claude-3", False, "Different provider"), + ("openai/gpt-4", False, "OpenAI model"), + ] + + all_passed = True + for model, expected, description in test_cases: + llm = LLM(model=model) + result = llm._is_ollama_model(model) + if result == expected: + print(f"✅ {description}: {model} -> {result}") + else: + print(f"❌ {description}: {model} -> {result} (expected {expected})") + all_passed = False + + return all_passed + +if __name__ == "__main__": + print("Testing Ollama response_format fix...") + + success1 = test_original_issue() + success2 = test_non_ollama_models() + success3 = test_ollama_model_detection_edge_cases() + + if success1 and success2 and success3: + print("\n🎉 All tests passed! The fix is working correctly.") + else: + print("\n💥 Some tests failed. The fix needs more work.")