Compare commits


2 Commits

Devin AI · 9bdc7b9eef · 2026-03-04 12:21:57 +00:00
fix: address review comments - add is_litellm guard, remove unused import/var

- Guard _get_llm_extra_kwargs with is_litellm check to avoid passing
  litellm-specific kwargs to non-litellm instructor clients
- Remove unused pytest import
- Remove unused result variable in test_to_pydantic_forwards_api_key
- Add test for non-litellm path returning empty kwargs

Co-Authored-By: João <joao@crewai.com>

Devin AI · 9dabb3e81c · 2026-03-04 12:15:16 +00:00
fix: support remote Ollama server for function calling and pydantic output

Bug 1: InternalInstructor.to_pydantic() now forwards api_base/base_url/api_key
to litellm so that remote Ollama servers are reachable during structured output
parsing.

Bug 2: LLM.supports_function_calling() now falls back to querying the remote
Ollama /api/show endpoint when litellm returns False for Ollama models with a
non-localhost base URL.

Fixes #4694

Co-Authored-By: João <joao@crewai.com>
3 changed files with 574 additions and 2 deletions
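
Below is a minimal sketch of the setup these commits target: a CrewAI LLM pointed at an Ollama server on another host. The model name and URL are illustrative, and is_litellm=True simply mirrors how the new tests construct the LLM; it is not necessarily required in every configuration.

from crewai.llm import LLM

# Illustrative remote host; any Ollama server that is not localhost:11434
# exercises both fixed code paths.
llm = LLM(
    model="ollama_chat/mistral-small3.2:24b",
    base_url="http://remote-server:11434",
    is_litellm=True,
)

# Bug 2 path: with the fix, this can fall back to probing the remote /api/show
# endpoint when litellm itself reports no function-calling support.
print(llm.supports_function_calling())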


@@ -2150,12 +2150,83 @@ class LLM(BaseLLM):
"Please remove response_format or use a supported model."
)
    def _is_ollama_model(self) -> bool:
        """Check if the model uses an Ollama provider.

        Returns:
            True if the model string indicates an Ollama provider.
        """
        model_lower = self.model.lower()
        return model_lower.startswith(("ollama/", "ollama_chat/"))

    def _get_ollama_base_url(self) -> str | None:
        """Get the base URL for the Ollama server.

        Returns the configured api_base or base_url if set, otherwise None
        (indicating the default localhost would be used by litellm).

        Returns:
            The Ollama server base URL, or None if not explicitly configured.
        """
        return self.api_base or self.base_url

    def _check_ollama_function_calling(self) -> bool:
        """Query a remote Ollama server to check if the model supports function calling.

        This is used as a fallback when litellm's built-in check fails because it
        cannot reach a remote Ollama server (it defaults to localhost:11434).

        Returns:
            True if the remote Ollama model supports function calling (tools), False otherwise.
        """
        ollama_url = self._get_ollama_base_url()
        if not ollama_url:
            return False

        # Extract the model name without the provider prefix
        model_name = self.model
        if "/" in model_name:
            model_name = model_name.split("/", 1)[1]

        try:
            url = f"{ollama_url.rstrip('/')}/api/show"
            response = httpx.post(url, json={"name": model_name}, timeout=5.0)
            if response.status_code == 200:
                data = response.json()
                # Check model capabilities/template for tool support
                model_info = data.get("model_info", {})
                template = data.get("template", "")
                # Check if any key in model_info contains tool-related capabilities
                for key, value in model_info.items():
                    if "tool" in key.lower() and value is True:
                        return True
                # Check if the template mentions tools
                if "tools" in template.lower() or ".ToolCalls" in template:
                    return True
            return False
        except Exception as e:
            logging.debug(f"Failed to query remote Ollama for function calling support: {e!s}")
            return False

    def supports_function_calling(self) -> bool:
        try:
            provider = self._get_custom_llm_provider()
            result = litellm.utils.supports_function_calling(
                self.model, custom_llm_provider=provider
            )
            if result:
                return True
            # Fallback: For Ollama models with a remote base_url, litellm's check
            # fails because it tries to query localhost:11434. Query the remote
            # Ollama server directly instead.
            if self._is_ollama_model() and self._get_ollama_base_url():
                return self._check_ollama_function_calling()
            return False
        except Exception as e:
            logging.error(f"Failed to check function calling support: {e!s}")
            return False
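
As a rough standalone illustration of what the fallback relies on, the sketch below probes an Ollama server's /api/show endpoint the same way _check_ollama_function_calling does. The host and model name are placeholders, and the response fields (model_info, template) are simply the ones the helper inspects, not a guaranteed schema for every Ollama version.

import httpx


def model_supports_tools(ollama_url: str, model_name: str) -> bool:
    # Mirror of _check_ollama_function_calling: POST /api/show, then look for
    # tool-related keys in model_info or tool markers in the prompt template.
    response = httpx.post(
        f"{ollama_url.rstrip('/')}/api/show",
        json={"name": model_name},
        timeout=5.0,
    )
    if response.status_code != 200:
        return False
    data = response.json()
    has_tool_key = any(
        "tool" in key.lower() and value is True
        for key, value in data.get("model_info", {}).items()
    )
    template = data.get("template", "")
    return has_tool_key or "tools" in template.lower() or ".ToolCalls" in template


# Example call (placeholder host/model):
# model_supports_tools("http://remote-server:11434", "mistral-small3.2:24b")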


@@ -124,6 +124,31 @@ class InternalInstructor(Generic[T]):
        pydantic_model = self.to_pydantic()
        return pydantic_model.model_dump_json(indent=2)

    def _get_llm_extra_kwargs(self) -> dict[str, Any]:
        """Extract extra keyword arguments from the LLM for forwarding to litellm.

        This ensures that provider-specific parameters like api_base and base_url
        are forwarded to litellm completion calls, which is critical for remote
        providers (e.g., remote Ollama servers).

        Returns:
            A dict of extra keyword arguments to pass to the completion call.
        """
        if self.llm is None or isinstance(self.llm, str):
            return {}
        # Only forward these kwargs for litellm-backed clients; non-litellm
        # instructor clients (from_provider) don't accept them and would raise.
        if not getattr(self.llm, "is_litellm", False):
            return {}
        extra: dict[str, Any] = {}
        for attr in ("api_base", "base_url", "api_key"):
            value = getattr(self.llm, attr, None)
            if value is not None:
                extra[attr] = value
        return extra

    def to_pydantic(self) -> T:
        """Generate structured output using the specified Pydantic model.

@@ -145,6 +145,9 @@ class InternalInstructor(Generic[T]):

        else:
            model_name = self.llm.model
        extra_kwargs = self._get_llm_extra_kwargs()
        return self._client.chat.completions.create(  # type: ignore[no-any-return]
            model=model_name, response_model=self.model, messages=messages,
            **extra_kwargs,
        )
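
To make the forwarding concrete, here is a small sketch of the kwargs _get_llm_extra_kwargs would collect for a litellm-backed LLM that only has a remote api_base set. The SimpleNamespace stands in for the real LLM object and the values are illustrative.

from types import SimpleNamespace

# Stand-in for a litellm-backed LLM with only api_base configured.
fake_llm = SimpleNamespace(
    is_litellm=True,
    api_base="http://remote-server:11434",
    base_url=None,
    api_key=None,
)

# Same selection logic as the loop in _get_llm_extra_kwargs.
extra: dict[str, object] = {}
for attr in ("api_base", "base_url", "api_key"):
    value = getattr(fake_llm, attr, None)
    if value is not None:
        extra[attr] = value
print(extra)  # {'api_base': 'http://remote-server:11434'}

# These kwargs are splatted into the instructor create() call, so litellm routes
# the structured-output request to the remote Ollama server instead of localhost.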


@@ -0,0 +1,473 @@
"""Tests for remote Ollama server support (Issue #4694).
Bug 1: InternalInstructor.to_pydantic() doesn't forward api_base/base_url to litellm.
Bug 2: LLM.supports_function_calling() doesn't query remote Ollama for capabilities.
"""

from unittest.mock import Mock, patch

import httpx
from pydantic import BaseModel

from crewai.llm import LLM
from crewai.utilities.internal_instructor import InternalInstructor


class SimpleModel(BaseModel):
    name: str
    age: int


# =====================================================================
# Bug 1: InternalInstructor forwards api_base/base_url to litellm
# =====================================================================


def _make_instructor_bypass_init(llm: object) -> "InternalInstructor[SimpleModel]":
    """Create an InternalInstructor bypassing __init__ to avoid instructor import.

    This is useful for testing helper methods without needing to mock instructor.
    """
    inst: InternalInstructor[SimpleModel] = object.__new__(InternalInstructor)
    inst.content = "test"
    inst.model = SimpleModel
    inst.llm = llm
    inst.agent = None
    inst._client = Mock()
    return inst


class TestInternalInstructorForwardsApiBase:
    """Test that InternalInstructor passes api_base/base_url to litellm completion."""

    def _make_litellm_instructor(
        self, mock_llm: Mock, mock_client: Mock
    ) -> "InternalInstructor[SimpleModel]":
        """Create an InternalInstructor with litellm path, mock client injected."""
        inst = _make_instructor_bypass_init(mock_llm)
        inst._client = mock_client
        return inst

    def test_to_pydantic_forwards_api_base(self) -> None:
        """When LLM has api_base set, it should be forwarded to the create() call."""
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.model = "ollama_chat/mistral-small3.2:24b"
        mock_llm.api_base = "http://remote-server:11434"
        mock_llm.base_url = None
        mock_llm.api_key = None
        mock_client = Mock()
        mock_client.chat.completions.create.return_value = SimpleModel(
            name="Test", age=25
        )
        inst = self._make_litellm_instructor(mock_llm, mock_client)
        result = inst.to_pydantic()
        assert isinstance(result, SimpleModel)
        assert result.name == "Test"
        assert result.age == 25
        # Verify api_base was forwarded
        call_kwargs = mock_client.chat.completions.create.call_args
        assert call_kwargs.kwargs.get("api_base") == "http://remote-server:11434"

    def test_to_pydantic_forwards_base_url(self) -> None:
        """When LLM has base_url set, it should be forwarded to the create() call."""
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.model = "ollama/mistral-small3.2:24b"
        mock_llm.api_base = None
        mock_llm.base_url = "http://remote-server:11434"
        mock_llm.api_key = None
        mock_client = Mock()
        mock_client.chat.completions.create.return_value = SimpleModel(
            name="Test", age=30
        )
        inst = self._make_litellm_instructor(mock_llm, mock_client)
        result = inst.to_pydantic()
        assert isinstance(result, SimpleModel)
        call_kwargs = mock_client.chat.completions.create.call_args
        assert call_kwargs.kwargs.get("base_url") == "http://remote-server:11434"

    def test_to_pydantic_forwards_api_key(self) -> None:
        """When LLM has api_key set, it should be forwarded to the create() call."""
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.model = "ollama/mistral-small3.2:24b"
        mock_llm.api_base = "http://remote-server:11434"
        mock_llm.base_url = None
        mock_llm.api_key = "test-key-123"
        mock_client = Mock()
        mock_client.chat.completions.create.return_value = SimpleModel(
            name="Test", age=35
        )
        inst = self._make_litellm_instructor(mock_llm, mock_client)
        inst.to_pydantic()
        call_kwargs = mock_client.chat.completions.create.call_args
        assert call_kwargs.kwargs.get("api_base") == "http://remote-server:11434"
        assert call_kwargs.kwargs.get("api_key") == "test-key-123"

    def test_to_pydantic_no_extra_kwargs_for_string_llm(self) -> None:
        """When LLM is a string, no extra kwargs should be passed."""
        mock_client = Mock()
        mock_client.chat.completions.create.return_value = SimpleModel(
            name="Test", age=20
        )
        inst = _make_instructor_bypass_init("openai/gpt-4o")
        inst._client = mock_client
        inst.to_pydantic()
        call_kwargs = mock_client.chat.completions.create.call_args
        # No extra kwargs should be present for string LLM
        assert "api_base" not in (call_kwargs.kwargs or {})
        assert "base_url" not in (call_kwargs.kwargs or {})
        assert "api_key" not in (call_kwargs.kwargs or {})

    def test_to_pydantic_no_extra_kwargs_when_none(self) -> None:
        """When LLM has no api_base/base_url/api_key, no extra kwargs should be passed."""
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.model = "gpt-4o"
        mock_llm.api_base = None
        mock_llm.base_url = None
        mock_llm.api_key = None
        mock_client = Mock()
        mock_client.chat.completions.create.return_value = SimpleModel(
            name="Test", age=40
        )
        inst = self._make_litellm_instructor(mock_llm, mock_client)
        inst.to_pydantic()
        call_kwargs = mock_client.chat.completions.create.call_args
        assert "api_base" not in (call_kwargs.kwargs or {})
        assert "base_url" not in (call_kwargs.kwargs or {})
        assert "api_key" not in (call_kwargs.kwargs or {})


class TestGetLlmExtraKwargs:
    """Test the _get_llm_extra_kwargs helper method."""

    def test_returns_empty_for_none_llm(self) -> None:
        inst = _make_instructor_bypass_init(llm=None)
        assert inst._get_llm_extra_kwargs() == {}

    def test_returns_empty_for_string_llm(self) -> None:
        inst = _make_instructor_bypass_init(llm="gpt-4o")
        assert inst._get_llm_extra_kwargs() == {}

    def test_returns_empty_for_non_litellm(self) -> None:
        mock_llm = Mock()
        mock_llm.is_litellm = False
        mock_llm.api_base = "http://remote:11434"
        mock_llm.base_url = None
        mock_llm.api_key = None
        inst = _make_instructor_bypass_init(llm=mock_llm)
        assert inst._get_llm_extra_kwargs() == {}

    def test_returns_api_base_when_set(self) -> None:
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.api_base = "http://remote:11434"
        mock_llm.base_url = None
        mock_llm.api_key = None
        inst = _make_instructor_bypass_init(llm=mock_llm)
        extra = inst._get_llm_extra_kwargs()
        assert extra == {"api_base": "http://remote:11434"}

    def test_returns_multiple_attrs_when_set(self) -> None:
        mock_llm = Mock()
        mock_llm.is_litellm = True
        mock_llm.api_base = "http://remote:11434"
        mock_llm.base_url = "http://remote:11434"
        mock_llm.api_key = "secret"
        inst = _make_instructor_bypass_init(llm=mock_llm)
        extra = inst._get_llm_extra_kwargs()
        assert extra == {
            "api_base": "http://remote:11434",
            "base_url": "http://remote:11434",
            "api_key": "secret",
        }


# =====================================================================
# Bug 2: LLM.supports_function_calling() with remote Ollama
# =====================================================================


class TestIsOllamaModel:
    """Test the _is_ollama_model helper method."""

    def test_ollama_prefix(self) -> None:
        llm = LLM(model="ollama/mistral-small3.2:24b", is_litellm=True)
        assert llm._is_ollama_model() is True

    def test_ollama_chat_prefix(self) -> None:
        llm = LLM(model="ollama_chat/mistral-small3.2:24b", is_litellm=True)
        assert llm._is_ollama_model() is True

    def test_non_ollama_model(self) -> None:
        llm = LLM(model="gpt-4o", is_litellm=True)
        assert llm._is_ollama_model() is False

    def test_openai_model(self) -> None:
        llm = LLM(model="openai/gpt-4o", is_litellm=True)
        assert llm._is_ollama_model() is False


class TestGetOllamaBaseUrl:
    """Test the _get_ollama_base_url helper method."""

    def test_returns_api_base(self) -> None:
        llm = LLM(
            model="ollama/mistral",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        assert llm._get_ollama_base_url() == "http://remote:11434"

    def test_returns_base_url(self) -> None:
        llm = LLM(
            model="ollama/mistral",
            base_url="http://remote:11434",
            is_litellm=True,
        )
        assert llm._get_ollama_base_url() == "http://remote:11434"

    def test_api_base_takes_precedence(self) -> None:
        llm = LLM(
            model="ollama/mistral",
            api_base="http://api-base:11434",
            base_url="http://base-url:11434",
            is_litellm=True,
        )
        assert llm._get_ollama_base_url() == "http://api-base:11434"

    def test_returns_none_when_not_set(self) -> None:
        llm = LLM(model="ollama/mistral", is_litellm=True)
        assert llm._get_ollama_base_url() is None


class TestCheckOllamaFunctionCalling:
    """Test the _check_ollama_function_calling method."""

    def test_returns_true_when_tool_key_in_model_info(self) -> None:
        """Remote Ollama returns model_info with tool-related capability."""
        llm = LLM(
            model="ollama/mistral-small3.2:24b",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {
                "general.architecture": "mistral",
                "mistral.tool_call": True,
            },
            "template": "",
        }
        with patch("httpx.post", return_value=mock_response) as mock_post:
            result = llm._check_ollama_function_calling()
        assert result is True
        mock_post.assert_called_once_with(
            "http://remote:11434/api/show",
            json={"name": "mistral-small3.2:24b"},
            timeout=5.0,
        )

    def test_returns_true_when_tools_in_template(self) -> None:
        """Remote Ollama returns template mentioning tools."""
        llm = LLM(
            model="ollama_chat/mistral-small3.2:24b",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {},
            "template": "{{- if .ToolCalls }}\n[TOOL_CALLS]{{ range .ToolCalls }}",
        }
        with patch("httpx.post", return_value=mock_response):
            assert llm._check_ollama_function_calling() is True

    def test_returns_true_when_tools_keyword_in_template(self) -> None:
        """Remote Ollama returns template with 'tools' keyword."""
        llm = LLM(
            model="ollama/qwen2.5:32b",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {},
            "template": "{{- if .Tools }}\nAvailable tools:\n{{ range .Tools }}",
        }
        with patch("httpx.post", return_value=mock_response):
            assert llm._check_ollama_function_calling() is True

    def test_returns_false_when_no_tool_support(self) -> None:
        """Remote Ollama model without tool support."""
        llm = LLM(
            model="ollama/llama2",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {"general.architecture": "llama"},
            "template": "{{ .Prompt }}",
        }
        with patch("httpx.post", return_value=mock_response):
            assert llm._check_ollama_function_calling() is False

    def test_returns_false_when_no_base_url(self) -> None:
        """When no base URL is configured, should return False."""
        llm = LLM(model="ollama/mistral", is_litellm=True)
        assert llm._check_ollama_function_calling() is False

    def test_returns_false_on_connection_error(self) -> None:
        """When the remote server is unreachable, should return False."""
        llm = LLM(
            model="ollama/mistral",
            api_base="http://unreachable:11434",
            is_litellm=True,
        )
        with patch("httpx.post", side_effect=httpx.ConnectError("Connection refused")):
            assert llm._check_ollama_function_calling() is False

    def test_returns_false_on_404(self) -> None:
        """When the remote server returns 404, should return False."""
        llm = LLM(
            model="ollama/mistral",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 404
        with patch("httpx.post", return_value=mock_response):
            assert llm._check_ollama_function_calling() is False

    def test_strips_trailing_slash_from_url(self) -> None:
        """Trailing slash in api_base should be stripped."""
        llm = LLM(
            model="ollama/mistral",
            api_base="http://remote:11434/",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {"mistral.tool_call": True},
            "template": "",
        }
        with patch("httpx.post", return_value=mock_response) as mock_post:
            llm._check_ollama_function_calling()
        mock_post.assert_called_once_with(
            "http://remote:11434/api/show",
            json={"name": "mistral"},
            timeout=5.0,
        )


class TestSupportsFunctionCallingWithRemoteOllama:
    """Test the full supports_function_calling flow with remote Ollama."""

    def test_litellm_returns_true_no_fallback_needed(self) -> None:
        """When litellm says the model supports function calling, use that."""
        llm = LLM(model="gpt-4o", is_litellm=True)
        with patch(
            "litellm.utils.supports_function_calling", return_value=True
        ):
            assert llm.supports_function_calling() is True

    def test_remote_ollama_fallback_when_litellm_returns_false(self) -> None:
        """When litellm returns False for Ollama with remote URL, query the server."""
        llm = LLM(
            model="ollama_chat/mistral-small3.2:24b",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {"mistral.tool_call": True},
            "template": "",
        }
        with patch("litellm.utils.supports_function_calling", return_value=False):
            with patch("httpx.post", return_value=mock_response):
                assert llm.supports_function_calling() is True

    def test_no_fallback_for_non_ollama_models(self) -> None:
        """When litellm returns False for non-Ollama models, don't use fallback."""
        llm = LLM(model="gpt-3.5-turbo", is_litellm=True)
        with patch("litellm.utils.supports_function_calling", return_value=False):
            assert llm.supports_function_calling() is False

    def test_no_fallback_without_remote_url(self) -> None:
        """Ollama model without remote URL shouldn't trigger fallback."""
        llm = LLM(model="ollama/mistral", is_litellm=True)
        with patch("litellm.utils.supports_function_calling", return_value=False):
            assert llm.supports_function_calling() is False

    def test_fallback_returns_false_when_model_lacks_tools(self) -> None:
        """Ollama model that doesn't support tools returns False even with fallback."""
        llm = LLM(
            model="ollama/llama2",
            api_base="http://remote:11434",
            is_litellm=True,
        )
        mock_response = Mock()
        mock_response.status_code = 200
        mock_response.json.return_value = {
            "model_info": {"general.architecture": "llama"},
            "template": "{{ .Prompt }}",
        }
        with patch("litellm.utils.supports_function_calling", return_value=False):
            with patch("httpx.post", return_value=mock_response):
                assert llm.supports_function_calling() is False

    def test_exception_handling(self) -> None:
        """When litellm raises an exception, should return False gracefully."""
        llm = LLM(model="ollama/mistral", api_base="http://remote:11434", is_litellm=True)
        with patch(
            "litellm.utils.supports_function_calling",
            side_effect=Exception("Unexpected error"),
        ):
            assert llm.supports_function_calling() is False