Address code review feedback: improve model detection, parameter filtering, and test coverage

- Refactor _is_ollama_model to use constants for better maintainability
- Make parameter filtering more explicit with clear comments
- Add type hints for better code clarity
- Add comprehensive edge case tests for model detection
- Improve test docstrings with detailed descriptions
- Move integration test to proper tests/ directory structure
- Fix lint error in test script by adding assertion
- All tests passing locally with improved code quality

Co-Authored-By: João <joao@crewai.com>
Devin AI authored on 2025-06-28 21:40:38 +00:00
parent aa82ca5273
commit 7a19bfb4a9
4 changed files with 170 additions and 15 deletions

View File

@@ -372,16 +372,8 @@ class LLM(BaseLLM):
Returns:
bool: True if the model is from Ollama, False otherwise.
"""
-# Check if model starts with ollama/ prefix
-if model.startswith("ollama/"):
-    return True
-# Check if the provider extracted from the model is ollama
-if "/" in model:
-    provider = model.split("/")[0]
-    return provider == "ollama"
-return False
+OLLAMA_IDENTIFIERS = ("ollama/", "ollama:")
+return any(identifier in model.lower() for identifier in OLLAMA_IDENTIFIERS)
def _prepare_completion_params(
self,
@@ -430,7 +422,9 @@ class LLM(BaseLLM):
**self.additional_params,
}
-if not self._is_ollama_model(self.model):
+if self._is_ollama_model(self.model):
+    params.pop("response_format", None)  # Remove safely if exists
+else:
params["response_format"] = self.response_format
# Remove None values from params
@@ -1091,7 +1085,7 @@ class LLM(BaseLLM):
if self._is_ollama_model(self.model):
return
-provider = self._get_custom_llm_provider()
+provider: str = self._get_custom_llm_provider()
if self.response_format is not None and not supports_response_schema(
model=self.model,
custom_llm_provider=provider,
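
The first two hunks above detect Ollama models case-insensitively via a small constant tuple and drop response_format from the completion params for those models. For reference, a minimal standalone sketch of that behaviour (the helper names are simplified stand-ins for _is_ollama_model and _prepare_completion_params; this is an illustration, not the actual crewai.llm.LLM class):

# Illustrative sketch only -- mirrors the logic in the diff above, not the real crewai.llm.LLM.
OLLAMA_IDENTIFIERS = ("ollama/", "ollama:")

def is_ollama_model(model: str) -> bool:
    # Case-insensitive substring check against the known Ollama identifiers.
    return any(identifier in model.lower() for identifier in OLLAMA_IDENTIFIERS)

def prepare_completion_params(model: str, response_format=None, **additional_params) -> dict:
    # Build completion params; Ollama does not support response_format, so it is
    # removed for Ollama models and kept for every other provider.
    params = {"model": model, **additional_params}
    if is_ollama_model(model):
        params.pop("response_format", None)  # remove safely if present
    else:
        params["response_format"] = response_format
    return params

assert is_ollama_model("ollama/gemma3:latest")
assert not is_ollama_model("gpt-4")
assert "response_format" not in prepare_completion_params("ollama/llama3.2:3b", response_format=dict)
assert prepare_completion_params("gpt-4", response_format=dict)["response_format"] is dict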

View File

@@ -33,6 +33,8 @@ def test_original_issue():
)
print("✅ Agent creation with Ollama LLM succeeded")
assert agent.llm.model == "ollama/gemma3:latest"
except ValueError as e:
if "does not support response_format" in str(e):
print(f"❌ Original issue still exists: {e}")

View File

@@ -1691,7 +1691,17 @@ def test_agent_execute_task_with_ollama():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_ollama_model_with_response_format():
"""Test that Ollama models work correctly when response_format is provided."""
"""
Test Ollama model compatibility with response_format parameter.
Verifies:
- LLM initialization with response_format doesn't raise ValueError
- Agent creation with formatted LLM succeeds
- Successful execution without raising ValueError for unsupported response_format
Note: This test may fail in CI due to Ollama server not being available,
but the core functionality (no ValueError on initialization) should work.
"""
from pydantic import BaseModel
class TestOutput(BaseModel):
@@ -1719,7 +1729,14 @@ def test_ollama_model_with_response_format():
@pytest.mark.vcr(filter_headers=["authorization"])
def test_ollama_model_response_format_filtered_in_params():
"""Test that response_format is filtered out for Ollama models in _prepare_completion_params."""
"""
Test that response_format is filtered out for Ollama models in _prepare_completion_params.
Verifies:
- Ollama model detection works correctly for various model formats
- response_format parameter is excluded from completion params for Ollama models
- Model detection returns correct boolean values for different model types
"""
from pydantic import BaseModel
class TestOutput(BaseModel):
@@ -1739,7 +1756,14 @@ def test_ollama_model_response_format_filtered_in_params():
def test_non_ollama_model_keeps_response_format():
"""Test that non-Ollama models still include response_format in params."""
"""
Test that non-Ollama models still include response_format in params.
Verifies:
- Non-Ollama models are correctly identified as such
- response_format parameter is preserved for non-Ollama models
- Backward compatibility is maintained for existing LLM providers
"""
from pydantic import BaseModel
class TestOutput(BaseModel):
@@ -1756,6 +1780,35 @@ def test_non_ollama_model_keeps_response_format():
assert params.get("response_format") == TestOutput
def test_ollama_model_detection_edge_cases():
"""
Test edge cases for Ollama model detection.
Verifies:
- Various Ollama model naming patterns are correctly identified
- Case-insensitive detection works properly
- Non-Ollama models containing 'ollama' in name are not misidentified
- Different provider prefixes are handled correctly
"""
from crewai.llm import LLM
test_cases = [
("ollama/llama3.2:3b", True, "Standard ollama/ prefix"),
("OLLAMA/MODEL:TAG", True, "Uppercase ollama/ prefix"),
("ollama:custom-model", True, "ollama: prefix"),
("custom/ollama-model", False, "Contains 'ollama' but not prefix"),
("gpt-4", False, "Non-Ollama model"),
("anthropic/claude-3", False, "Different provider"),
("openai/gpt-4", False, "OpenAI model"),
("ollama/gemma3:latest", True, "Ollama with version tag"),
]
for model_name, expected, description in test_cases:
llm = LLM(model=model_name)
result = llm._is_ollama_model(model_name)
assert result == expected, f"Failed for {description}: {model_name} -> {result} (expected {expected})"
@pytest.mark.vcr(filter_headers=["authorization"])
def test_agent_with_knowledge_sources():
content = "Brandon's favorite color is red and he likes Mexican food."

View File

@@ -0,0 +1,106 @@
"""
Integration tests for Ollama model handling.
This module tests the Ollama-specific functionality including response_format handling.
"""
from pydantic import BaseModel
from crewai.llm import LLM
from crewai import Agent
class GuideOutline(BaseModel):
title: str
sections: list[str]
def test_original_issue():
"""Test the original issue scenario from GitHub issue #3082."""
print("Testing original issue scenario...")
try:
llm = LLM(model="ollama/gemma3:latest", response_format=GuideOutline)
print("✅ LLM creation with response_format succeeded")
params = llm._prepare_completion_params("Test message")
if "response_format" not in params or params.get("response_format") is None:
print("✅ response_format correctly filtered out for Ollama model")
else:
print("❌ response_format was not filtered out")
agent = Agent(
role="Guide Creator",
goal="Create comprehensive guides",
backstory="You are an expert at creating structured guides",
llm=llm
)
print("✅ Agent creation with Ollama LLM succeeded")
assert agent.llm.model == "ollama/gemma3:latest"
except ValueError as e:
if "does not support response_format" in str(e):
print(f"❌ Original issue still exists: {e}")
return False
else:
print(f"❌ Unexpected ValueError: {e}")
return False
except Exception as e:
print(f"❌ Unexpected error: {e}")
return False
return True
def test_non_ollama_models():
"""Test that non-Ollama models still work with response_format."""
print("\nTesting non-Ollama models...")
try:
llm = LLM(model="gpt-4", response_format=GuideOutline)
params = llm._prepare_completion_params("Test message")
if params.get("response_format") == GuideOutline:
print("✅ Non-Ollama models still include response_format")
return True
else:
print("❌ Non-Ollama models missing response_format")
return False
except Exception as e:
print(f"❌ Error with non-Ollama model: {e}")
return False
def test_ollama_model_detection_edge_cases():
"""Test edge cases for Ollama model detection."""
print("\nTesting Ollama model detection edge cases...")
test_cases = [
("ollama/llama3.2:3b", True, "Standard ollama/ prefix"),
("OLLAMA/MODEL:TAG", True, "Uppercase ollama/ prefix"),
("ollama:custom-model", True, "ollama: prefix"),
("custom/ollama-model", False, "Contains 'ollama' but not prefix"),
("gpt-4", False, "Non-Ollama model"),
("anthropic/claude-3", False, "Different provider"),
("openai/gpt-4", False, "OpenAI model"),
]
all_passed = True
for model, expected, description in test_cases:
llm = LLM(model=model)
result = llm._is_ollama_model(model)
if result == expected:
print(f"{description}: {model} -> {result}")
else:
print(f"{description}: {model} -> {result} (expected {expected})")
all_passed = False
return all_passed
if __name__ == "__main__":
print("Testing Ollama response_format fix...")
success1 = test_original_issue()
success2 = test_non_ollama_models()
success3 = test_ollama_model_detection_edge_cases()
if success1 and success2 and success3:
print("\n🎉 All tests passed! The fix is working correctly.")
else:
print("\n💥 Some tests failed. The fix needs more work.")