From fdfb88dc6c20ce5d6e186545703ceb1312eb3c4e Mon Sep 17 00:00:00 2001 From: "Paul S. Schweigert" Date: Tue, 19 May 2026 21:31:14 -0400 Subject: [PATCH] fix: add ollama async timeout Signed-off-by: Paul S. Schweigert --- mellea/backends/ollama.py | 14 ++++++-- test/backends/test_ollama_unit.py | 56 +++++++++++++++++++++++++++++-- 2 files changed, 66 insertions(+), 4 deletions(-) diff --git a/mellea/backends/ollama.py b/mellea/backends/ollama.py index bc436bece..7ef5f3bef 100644 --- a/mellea/backends/ollama.py +++ b/mellea/backends/ollama.py @@ -51,6 +51,10 @@ class OllamaModelBackend(FormatterBackend): base_url (str | None): Ollama server endpoint; defaults to ``env(OLLAMA_HOST)`` or ``http://localhost:11434``. model_options (dict | None): Default model options for generation requests. + timeout (float | None): Request timeout in seconds for the underlying HTTP + client. ``None`` (the default) preserves the upstream ``ollama`` SDK + default. Set this to bound how long a single request will wait when + the Ollama server is overloaded or stalled. Attributes: to_mellea_model_opts_map (dict): Mapping from Ollama-specific option names @@ -65,6 +69,7 @@ def __init__( formatter: ChatFormatter | None = None, base_url: str | None = None, model_options: dict | None = None, + timeout: float | None = None, ): """Initialize an Ollama backend, connecting to the server and pulling the model if needed.""" super().__init__( @@ -81,7 +86,12 @@ def __init__( # Setup the client and ensure that we have the model available. self._base_url = base_url - self._client = ollama.Client(base_url) + self._timeout = timeout + client_kwargs: dict[str, Any] = {} + if timeout is not None: + client_kwargs["timeout"] = timeout + self._client_kwargs = client_kwargs + self._client = ollama.Client(base_url, **client_kwargs) self._client_cache = ClientCache(2) @@ -207,7 +217,7 @@ def _async_client(self) -> ollama.AsyncClient: _async_client = self._client_cache.get(key) if _async_client is None: - _async_client = ollama.AsyncClient(self._base_url) + _async_client = ollama.AsyncClient(self._base_url, **self._client_kwargs) self._client_cache.put(key, _async_client) return _async_client diff --git a/test/backends/test_ollama_unit.py b/test/backends/test_ollama_unit.py index 9ade604b7..41261a159 100644 --- a/test/backends/test_ollama_unit.py +++ b/test/backends/test_ollama_unit.py @@ -14,7 +14,9 @@ from mellea.core import ModelOutputThunk -def _make_backend(model_options: dict | None = None) -> OllamaModelBackend: +def _make_backend( + model_options: dict | None = None, timeout: float | None = None +) -> OllamaModelBackend: """Return an OllamaModelBackend with all network calls patched.""" with ( patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True), @@ -22,7 +24,9 @@ def _make_backend(model_options: dict | None = None) -> OllamaModelBackend: patch("mellea.backends.ollama.ollama.Client", return_value=MagicMock()), patch("mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock()), ): - return OllamaModelBackend(model_id="granite3.3:8b", model_options=model_options) + return OllamaModelBackend( + model_id="granite3.3:8b", model_options=model_options, timeout=timeout + ) @pytest.fixture @@ -147,5 +151,53 @@ def test_delta_merge_thinking_concatenated(): assert mot._meta["chat_response"].message.thinking == "step 1 step 2" +# --- timeout wiring --- + + +def test_timeout_default_not_passed_to_clients(): + """When timeout is omitted, it must not be forwarded to the Ollama clients.""" + with ( + patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True), + patch.object(OllamaModelBackend, "_pull_ollama_model", return_value=True), + patch("mellea.backends.ollama.ollama.Client") as mock_client, + patch("mellea.backends.ollama.ollama.AsyncClient") as mock_async_client, + ): + OllamaModelBackend(model_id="granite3.3:8b") + + _, sync_kwargs = mock_client.call_args + assert "timeout" not in sync_kwargs + _, async_kwargs = mock_async_client.call_args + assert "timeout" not in async_kwargs + + +def test_timeout_forwarded_to_sync_and_async_clients(): + """When timeout is set, it must reach both ollama.Client and ollama.AsyncClient.""" + with ( + patch.object(OllamaModelBackend, "_check_ollama_server", return_value=True), + patch.object(OllamaModelBackend, "_pull_ollama_model", return_value=True), + patch("mellea.backends.ollama.ollama.Client") as mock_client, + patch("mellea.backends.ollama.ollama.AsyncClient") as mock_async_client, + ): + OllamaModelBackend(model_id="granite3.3:8b", timeout=12.5) + + _, sync_kwargs = mock_client.call_args + assert sync_kwargs.get("timeout") == 12.5 + _, async_kwargs = mock_async_client.call_args + assert async_kwargs.get("timeout") == 12.5 + + +def test_timeout_forwarded_to_new_async_clients_per_event_loop(): + """Newly created AsyncClients (one per event loop) must inherit the timeout.""" + backend = _make_backend(timeout=7.0) + with patch( + "mellea.backends.ollama.ollama.AsyncClient", return_value=MagicMock() + ) as mock_async_client: + backend._client_cache = type(backend._client_cache)(2) # reset cache + _ = backend._async_client + + _, async_kwargs = mock_async_client.call_args + assert async_kwargs.get("timeout") == 7.0 + + if __name__ == "__main__": pytest.main([__file__, "-v"])