mirror of
https://github.com/crewAIInc/crewAI.git
synced 2026-01-09 16:18:30 +00:00
Fix Ollama models ValueError when response_format is used
- Add _is_ollama_model method to detect Ollama models consistently - Skip response_format validation for Ollama models in _validate_call_params - Filter out response_format parameter for Ollama models in _prepare_completion_params - Add comprehensive tests for Ollama response_format handling - Maintain backward compatibility for other LLM providers Fixes #3082 Co-Authored-By: João <joao@crewai.com>
This commit is contained in:
@@ -363,6 +363,26 @@ class LLM(BaseLLM):
|
||||
ANTHROPIC_PREFIXES = ("anthropic/", "claude-", "claude/")
|
||||
return any(prefix in model.lower() for prefix in ANTHROPIC_PREFIXES)
|
||||
|
||||
def _is_ollama_model(self, model: str) -> bool:
|
||||
"""Determine if the model is from Ollama provider.
|
||||
|
||||
Args:
|
||||
model: The model identifier string.
|
||||
|
||||
Returns:
|
||||
bool: True if the model is from Ollama, False otherwise.
|
||||
"""
|
||||
# Check if model starts with ollama/ prefix
|
||||
if model.startswith("ollama/"):
|
||||
return True
|
||||
|
||||
# Check if the provider extracted from the model is ollama
|
||||
if "/" in model:
|
||||
provider = model.split("/")[0]
|
||||
return provider == "ollama"
|
||||
|
||||
return False
|
||||
|
||||
def _prepare_completion_params(
|
||||
self,
|
||||
messages: Union[str, List[Dict[str, str]]],
|
||||
@@ -397,7 +417,6 @@ class LLM(BaseLLM):
|
||||
"presence_penalty": self.presence_penalty,
|
||||
"frequency_penalty": self.frequency_penalty,
|
||||
"logit_bias": self.logit_bias,
|
||||
"response_format": self.response_format,
|
||||
"seed": self.seed,
|
||||
"logprobs": self.logprobs,
|
||||
"top_logprobs": self.top_logprobs,
|
||||
@@ -411,6 +430,9 @@ class LLM(BaseLLM):
|
||||
**self.additional_params,
|
||||
}
|
||||
|
||||
if not self._is_ollama_model(self.model):
|
||||
params["response_format"] = self.response_format
|
||||
|
||||
# Remove None values from params
|
||||
return {k: v for k, v in params.items() if v is not None}
|
||||
|
||||
@@ -1065,6 +1087,10 @@ class LLM(BaseLLM):
|
||||
- "gemini/gemini-1.5-pro" yields "gemini"
|
||||
- If no slash is present, "openai" is assumed.
|
||||
"""
|
||||
# Skip validation for Ollama models as they don't support response_format
|
||||
if self._is_ollama_model(self.model):
|
||||
return
|
||||
|
||||
provider = self._get_custom_llm_provider()
|
||||
if self.response_format is not None and not supports_response_schema(
|
||||
model=self.model,
|
||||
|
||||
77
test_ollama_fix.py
Normal file
77
test_ollama_fix.py
Normal file
@@ -0,0 +1,77 @@
|
||||
"""
|
||||
Reproduction script for issue #3082 - Ollama response_format error.
|
||||
This script reproduces the original issue and verifies the fix.
|
||||
"""
|
||||
|
||||
from pydantic import BaseModel
|
||||
from crewai.llm import LLM
|
||||
from crewai import Agent
|
||||
|
||||
class GuideOutline(BaseModel):
|
||||
title: str
|
||||
sections: list[str]
|
||||
|
||||
def test_original_issue():
|
||||
"""Test the original issue scenario from the GitHub issue."""
|
||||
print("Testing original issue scenario...")
|
||||
|
||||
try:
|
||||
llm = LLM(model="ollama/gemma3:latest", response_format=GuideOutline)
|
||||
print("✅ LLM creation with response_format succeeded")
|
||||
|
||||
params = llm._prepare_completion_params("Test message")
|
||||
if "response_format" not in params or params.get("response_format") is None:
|
||||
print("✅ response_format correctly filtered out for Ollama model")
|
||||
else:
|
||||
print("❌ response_format was not filtered out")
|
||||
|
||||
agent = Agent(
|
||||
role="Guide Creator",
|
||||
goal="Create comprehensive guides",
|
||||
backstory="You are an expert at creating structured guides",
|
||||
llm=llm
|
||||
)
|
||||
print("✅ Agent creation with Ollama LLM succeeded")
|
||||
|
||||
except ValueError as e:
|
||||
if "does not support response_format" in str(e):
|
||||
print(f"❌ Original issue still exists: {e}")
|
||||
return False
|
||||
else:
|
||||
print(f"❌ Unexpected ValueError: {e}")
|
||||
return False
|
||||
except Exception as e:
|
||||
print(f"❌ Unexpected error: {e}")
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
def test_non_ollama_models():
|
||||
"""Test that non-Ollama models still work with response_format."""
|
||||
print("\nTesting non-Ollama models...")
|
||||
|
||||
try:
|
||||
llm = LLM(model="gpt-4", response_format=GuideOutline)
|
||||
params = llm._prepare_completion_params("Test message")
|
||||
|
||||
if params.get("response_format") == GuideOutline:
|
||||
print("✅ Non-Ollama models still include response_format")
|
||||
return True
|
||||
else:
|
||||
print("❌ Non-Ollama models missing response_format")
|
||||
return False
|
||||
|
||||
except Exception as e:
|
||||
print(f"❌ Error with non-Ollama model: {e}")
|
||||
return False
|
||||
|
||||
if __name__ == "__main__":
|
||||
print("Testing Ollama response_format fix...")
|
||||
|
||||
success1 = test_original_issue()
|
||||
success2 = test_non_ollama_models()
|
||||
|
||||
if success1 and success2:
|
||||
print("\n🎉 All tests passed! The fix is working correctly.")
|
||||
else:
|
||||
print("\n💥 Some tests failed. The fix needs more work.")
|
||||
@@ -1689,6 +1689,73 @@ def test_agent_execute_task_with_ollama():
|
||||
assert "AI" in result or "artificial intelligence" in result.lower()
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ollama_model_with_response_format():
|
||||
"""Test that Ollama models work correctly when response_format is provided."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestOutput(BaseModel):
|
||||
result: str
|
||||
|
||||
llm = LLM(
|
||||
model="ollama/llama3.2:3b",
|
||||
base_url="http://localhost:11434",
|
||||
response_format=TestOutput
|
||||
)
|
||||
|
||||
result = llm.call("What is 2+2?")
|
||||
assert result is not None
|
||||
|
||||
agent = Agent(
|
||||
role="test role",
|
||||
goal="test goal",
|
||||
backstory="test backstory",
|
||||
llm=llm
|
||||
)
|
||||
|
||||
output = agent.kickoff("What is 2+2?", response_format=TestOutput)
|
||||
assert output is not None
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_ollama_model_response_format_filtered_in_params():
|
||||
"""Test that response_format is filtered out for Ollama models in _prepare_completion_params."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestOutput(BaseModel):
|
||||
result: str
|
||||
|
||||
llm = LLM(
|
||||
model="ollama/llama3.2:3b",
|
||||
base_url="http://localhost:11434",
|
||||
response_format=TestOutput
|
||||
)
|
||||
|
||||
assert llm._is_ollama_model("ollama/llama3.2:3b") is True
|
||||
assert llm._is_ollama_model("gpt-4") is False
|
||||
|
||||
params = llm._prepare_completion_params("Test message")
|
||||
assert "response_format" not in params or params.get("response_format") is None
|
||||
|
||||
|
||||
def test_non_ollama_model_keeps_response_format():
|
||||
"""Test that non-Ollama models still include response_format in params."""
|
||||
from pydantic import BaseModel
|
||||
|
||||
class TestOutput(BaseModel):
|
||||
result: str
|
||||
|
||||
llm = LLM(
|
||||
model="gpt-4",
|
||||
response_format=TestOutput
|
||||
)
|
||||
|
||||
assert llm._is_ollama_model("gpt-4") is False
|
||||
|
||||
params = llm._prepare_completion_params("Test message")
|
||||
assert params.get("response_format") == TestOutput
|
||||
|
||||
|
||||
@pytest.mark.vcr(filter_headers=["authorization"])
|
||||
def test_agent_with_knowledge_sources():
|
||||
content = "Brandon's favorite color is red and he likes Mexican food."
|
||||
|
||||
Reference in New Issue
Block a user