Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 12 additions & 2 deletions mellea/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,10 @@ class OllamaModelBackend(FormatterBackend):
base_url (str | None): Ollama server endpoint; defaults to
``env(OLLAMA_HOST)`` or ``http://localhost:11434``.
model_options (dict | None): Default model options for generation requests.
timeout (float | None): Request timeout in seconds for the underlying HTTP
client. ``None`` (the default) preserves the upstream ``ollama`` SDK
default. Set this to bound how long a single request will wait when
the Ollama server is overloaded or stalled.

Attributes:
to_mellea_model_opts_map (dict): Mapping from Ollama-specific option names
Expand All @@ -65,6 +69,7 @@ def __init__(
formatter: ChatFormatter | None = None,
base_url: str | None = None,
model_options: dict | None = None,
timeout: float | None = None,
):
"""Initialize an Ollama backend, connecting to the server and pulling the model if needed."""
super().__init__(
Expand All @@ -81,7 +86,12 @@ def __init__(

# Setup the client and ensure that we have the model available.
self._base_url = base_url
self._client = ollama.Client(base_url)
self._timeout = timeout
client_kwargs: dict[str, Any] = {}
if timeout is not None:
client_kwargs["timeout"] = timeout
self._client_kwargs = client_kwargs
self._client = ollama.Client(base_url, **client_kwargs)

self._client_cache = ClientCache(2)

Expand Down Expand Up @@ -207,7 +217,7 @@ def _async_client(self) -> ollama.AsyncClient:

_async_client = self._client_cache.get(key)
if _async_client is None:
_async_client = ollama.AsyncClient(self._base_url)
_async_client = ollama.AsyncClient(self._base_url, **self._client_kwargs)
self._client_cache.put(key, _async_client)
return _async_client

Expand Down
56 changes: 54 additions & 2 deletions test/backends/test_ollama_unit.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,19 @@
from mellea.core import ModelOutputThunk


def _make_backend(
    model_options: dict | None = None, timeout: float | None = None
) -> OllamaModelBackend:
    """Return an OllamaModelBackend with all network calls patched.

    Args:
        model_options: Forwarded unchanged as the backend's ``model_options``.
        timeout: Forwarded unchanged as the backend's ``timeout``.

    Returns:
        A backend for ``granite3.3:8b`` constructed while the server check,
        the model pull, and both Ollama client classes were replaced by mocks.
    """
    # The patches are only active while the constructor runs; code that later
    # needs a mocked AsyncClient has to patch it again itself.
    with (
        patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True),
        patch.object(OllamaModelBackend, "_pull_ollama_model", return_value=True),
        patch("mellea.backends.ollama.ollama.Client", return_value=MagicMock()),
        patch("mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock()),
    ):
        return OllamaModelBackend(
            model_id="granite3.3:8b", model_options=model_options, timeout=timeout
        )


@pytest.fixture
Expand Down Expand Up @@ -147,5 +151,53 @@ def test_delta_merge_thinking_concatenated():
assert mot._meta["chat_response"].message.thinking == "step 1 step 2"


# --- timeout wiring ---


def test_timeout_default_not_passed_to_clients():
    """Omitting ``timeout`` must keep it out of both Ollama client constructors."""
    server_check = patch.object(
        OllamaModelBackend, "_check_ollama_server", return_value=True
    )
    model_pull = patch.object(
        OllamaModelBackend, "_pull_ollama_model", return_value=True
    )
    sync_patch = patch("mellea.backends.ollama.ollama.Client")
    async_patch = patch("mellea.backends.ollama.ollama.AsyncClient")

    # Patchers are entered in the same order as they were created above.
    with server_check, model_pull, sync_patch as sync_ctor, async_patch as async_ctor:
        OllamaModelBackend(model_id="granite3.3:8b")

    # Inspect the keyword arguments of the most recent constructor call.
    for ctor in (sync_ctor, async_ctor):
        assert "timeout" not in ctor.call_args.kwargs


def test_timeout_forwarded_to_sync_and_async_clients():
    """A configured ``timeout`` must reach ``ollama.Client`` and ``ollama.AsyncClient``."""
    server_check = patch.object(
        OllamaModelBackend, "_check_ollama_server", return_value=True
    )
    model_pull = patch.object(
        OllamaModelBackend, "_pull_ollama_model", return_value=True
    )
    sync_patch = patch("mellea.backends.ollama.ollama.Client")
    async_patch = patch("mellea.backends.ollama.ollama.AsyncClient")

    # Patchers are entered in the same order as they were created above.
    with server_check, model_pull, sync_patch as sync_ctor, async_patch as async_ctor:
        OllamaModelBackend(model_id="granite3.3:8b", timeout=12.5)

    # Both constructors must have received the exact value that was configured.
    for ctor in (sync_ctor, async_ctor):
        assert ctor.call_args.kwargs.get("timeout") == 12.5


def test_timeout_forwarded_to_new_async_clients_per_event_loop():
    """An AsyncClient built after construction must still receive the timeout."""
    backend = _make_backend(timeout=7.0)
    async_patch = patch(
        "mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock()
    )

    with async_patch as async_ctor:
        # Swap in an empty cache of the same type so the property has to
        # build a fresh client instead of returning a cached one.
        cache_cls = type(backend._client_cache)
        backend._client_cache = cache_cls(2)
        _ = backend._async_client

    forwarded = async_ctor.call_args.kwargs
    assert forwarded.get("timeout") == 7.0


if __name__ == "__main__":
    # Allow running this test module directly; "-v" turns on verbose output.
    pytest.main(args=[__file__, "-v"])
Loading