From df0b4cf1c4cccb3a4a481150ffe74a4dd5bd905d Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Thu, 30 Apr 2026 01:17:46 +0530 Subject: [PATCH 1/2] feat: add LiteLLM as AI gateway agent --- requirements.txt | 1 + src/client/agents/__init__.py | 1 + src/client/agents/litellm_agent.py | 69 ++++++++++++++++ tests/test_litellm_agent.py | 123 +++++++++++++++++++++++++++++ 4 files changed, 194 insertions(+) create mode 100644 src/client/agents/litellm_agent.py create mode 100644 tests/test_litellm_agent.py diff --git a/requirements.txt b/requirements.txt index 4a1668e0..aa9426fa 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ docker==6.1.2 SPARQLWrapper networkx~=2.8.4 anthropic~=0.4.1 +litellm>=1.60.0,<2.0 fschat~=0.2.31 accelerate~=0.23.0 transformers~=4.34.0 \ No newline at end of file diff --git a/src/client/agents/__init__.py b/src/client/agents/__init__.py index 8d0fe88c..d942a583 100644 --- a/src/client/agents/__init__.py +++ b/src/client/agents/__init__.py @@ -1,2 +1,3 @@ from .fastchat_client import FastChatAgent from .http_agent import HTTPAgent +from .litellm_agent import LiteLLMAgent diff --git a/src/client/agents/litellm_agent.py b/src/client/agents/litellm_agent.py new file mode 100644 index 00000000..178fd26f --- /dev/null +++ b/src/client/agents/litellm_agent.py @@ -0,0 +1,69 @@ +import os +import time +from copy import deepcopy +from typing import Any, Dict, List + +from src.typings import AgentClientException, AgentContextLimitException +from ..agent import AgentClient + + +class LiteLLMAgent(AgentClient): + """Agent client that routes to 100+ LLM providers via LiteLLM. + + Model names use LiteLLM format: "provider/model-name", e.g.: + anthropic/claude-sonnet-4-20250514, openai/gpt-4o, + bedrock/anthropic.claude-3-5-sonnet-20241022-v2:0 + + Provider API keys are read from standard environment variables + (ANTHROPIC_API_KEY, OPENAI_API_KEY, etc.). + + See https://docs.litellm.ai/docs/providers for the full list. + """ + + def __init__(self, model, api_key=None, api_base=None, api_args=None, **kwargs): + super().__init__(**kwargs) + if not model: + raise ValueError("model is required (e.g. 'anthropic/claude-sonnet-4-20250514')") + self.model = model + self.api_key = api_key + self.api_base = api_base + self.api_args = api_args or {} + + def inference(self, history: List[dict]) -> str: + import litellm + + messages = [] + for item in history: + role = "assistant" if item["role"] == "agent" else item["role"] + messages.append({"role": role, "content": item["content"]}) + + extra: Dict[str, Any] = {} + if self.api_key: + extra["api_key"] = self.api_key + if self.api_base: + extra["api_base"] = self.api_base + + for attempt in range(3): + try: + response = litellm.completion( + model=self.model, + messages=messages, + drop_params=True, + **self.api_args, + **extra, + ) + return response.choices[0].message.content + except Exception as e: + error_text = str(e).lower() + if any(kw in error_text for kw in ("context", "token", "limit", "exceed", "max")): + raise AgentContextLimitException(str(e)) + qualname = f"{type(e).__module__}.{type(e).__name__}" + if qualname in ( + "litellm.exceptions.AuthenticationError", + "litellm.exceptions.BadRequestError", + "litellm.exceptions.NotFoundError", + ): + raise + print("Warning: ", e) + time.sleep(attempt + 2) + raise Exception("Failed after 3 attempts.") diff --git a/tests/test_litellm_agent.py b/tests/test_litellm_agent.py new file mode 100644 index 00000000..65f78d97 --- /dev/null +++ b/tests/test_litellm_agent.py @@ -0,0 +1,123 @@ +import sys +import types +from unittest import mock +from unittest.mock import MagicMock + +import pytest + +sys.path.insert(0, ".") + +if "fastchat" not in sys.modules: + _fc = types.ModuleType("fastchat") + _fc_model = types.ModuleType("fastchat.model") + _fc_adapter = types.ModuleType("fastchat.model.model_adapter") + _fc_adapter.get_conversation_template = MagicMock() + _fc.model = _fc_model + _fc_model.model_adapter = _fc_adapter + sys.modules.update({"fastchat": _fc, "fastchat.model": _fc_model, "fastchat.model.model_adapter": _fc_adapter}) + +from src.client.agents.litellm_agent import LiteLLMAgent + + +def _mock_resp(content="Hello!"): + msg = MagicMock(); msg.content = content + choice = MagicMock(); choice.message = msg + resp = MagicMock(); resp.choices = [choice] + return resp + + +def _run_inference(agent, history, response): + fake = types.ModuleType("litellm") + fake.completion = MagicMock(return_value=response) + with mock.patch.dict(sys.modules, {"litellm": fake}): + result = agent.inference(history) + return result, fake.completion + + +class TestInit: + def test_default(self): + a = LiteLLMAgent(model="openai/gpt-4o") + assert a.model == "openai/gpt-4o" + assert a.api_key is None + assert a.api_args == {} + + def test_with_params(self): + a = LiteLLMAgent(model="anthropic/claude-sonnet-4-20250514", api_key="sk-test", api_args={"temperature": 0.7}) + assert a.api_key == "sk-test" + assert a.api_args["temperature"] == 0.7 + + def test_model_required(self): + with pytest.raises(ValueError): + LiteLLMAgent(model=None) + + +class TestInference: + def test_basic_completion(self): + a = LiteLLMAgent(model="openai/gpt-4o") + result, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp("42")) + assert result == "42" + comp.assert_called_once() + + def test_role_mapping(self): + a = LiteLLMAgent(model="openai/gpt-4o") + history = [ + {"role": "user", "content": "hi"}, + {"role": "agent", "content": "hello"}, + {"role": "user", "content": "bye"}, + ] + _, comp = _run_inference(a, history, _mock_resp("ok")) + msgs = comp.call_args.kwargs["messages"] + assert msgs[0]["role"] == "user" + assert msgs[1]["role"] == "assistant" + assert msgs[2]["role"] == "user" + + def test_drop_params(self): + a = LiteLLMAgent(model="openai/gpt-4o") + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert comp.call_args.kwargs["drop_params"] is True + + def test_api_key_forwarded(self): + a = LiteLLMAgent(model="openai/gpt-4o", api_key="sk-test") + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert comp.call_args.kwargs["api_key"] == "sk-test" + + def test_api_key_omitted(self): + a = LiteLLMAgent(model="openai/gpt-4o") + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert "api_key" not in comp.call_args.kwargs + + def test_api_args_forwarded(self): + a = LiteLLMAgent(model="openai/gpt-4o", api_args={"temperature": 0.5, "max_tokens": 100}) + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert comp.call_args.kwargs["temperature"] == 0.5 + assert comp.call_args.kwargs["max_tokens"] == 100 + + def test_api_base_forwarded(self): + a = LiteLLMAgent(model="openai/gpt-4o", api_base="https://proxy.local") + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert comp.call_args.kwargs["api_base"] == "https://proxy.local" + + def test_model_forwarded(self): + a = LiteLLMAgent(model="anthropic/claude-sonnet-4-20250514") + _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) + assert comp.call_args.kwargs["model"] == "anthropic/claude-sonnet-4-20250514" + + +class TestRegistration: + def test_importable(self): + from src.client.agents import LiteLLMAgent as Imported + assert Imported is LiteLLMAgent + + def test_subclass(self): + from src.client.agent import AgentClient + assert issubclass(LiteLLMAgent, AgentClient) + + def test_instance_factory(self): + from src.typings import InstanceFactory + factory = InstanceFactory( + module="src.client.agents.litellm_agent.LiteLLMAgent", + parameters={"model": "openai/gpt-4o"}, + ) + agent = factory.create() + assert isinstance(agent, LiteLLMAgent) + assert agent.model == "openai/gpt-4o" From 34fc915bb746771e3d902a48216d47c9491fc4cf Mon Sep 17 00:00:00 2001 From: RheagalFire Date: Thu, 30 Apr 2026 01:28:35 +0530 Subject: [PATCH 2/2] fix: use check_context_limit, guard None content, add retry/error tests --- src/client/agents/litellm_agent.py | 11 ++--- tests/test_litellm_agent.py | 67 ++++++++++++++++++++++++++++++ 2 files changed, 73 insertions(+), 5 deletions(-) diff --git a/src/client/agents/litellm_agent.py b/src/client/agents/litellm_agent.py index 178fd26f..6a3287da 100644 --- a/src/client/agents/litellm_agent.py +++ b/src/client/agents/litellm_agent.py @@ -1,9 +1,8 @@ -import os import time -from copy import deepcopy from typing import Any, Dict, List from src.typings import AgentClientException, AgentContextLimitException +from .http_agent import check_context_limit from ..agent import AgentClient @@ -25,6 +24,7 @@ def __init__(self, model, api_key=None, api_base=None, api_args=None, **kwargs): if not model: raise ValueError("model is required (e.g. 'anthropic/claude-sonnet-4-20250514')") self.model = model + self.model_name = model self.api_key = api_key self.api_base = api_base self.api_args = api_args or {} @@ -52,10 +52,11 @@ def inference(self, history: List[dict]) -> str: **self.api_args, **extra, ) - return response.choices[0].message.content + return response.choices[0].message.content or "" + except AgentClientException as e: + raise e except Exception as e: - error_text = str(e).lower() - if any(kw in error_text for kw in ("context", "token", "limit", "exceed", "max")): + if check_context_limit(str(e)): raise AgentContextLimitException(str(e)) qualname = f"{type(e).__module__}.{type(e).__name__}" if qualname in ( diff --git a/tests/test_litellm_agent.py b/tests/test_litellm_agent.py index 65f78d97..9dcd0af3 100644 --- a/tests/test_litellm_agent.py +++ b/tests/test_litellm_agent.py @@ -17,6 +17,7 @@ sys.modules.update({"fastchat": _fc, "fastchat.model": _fc_model, "fastchat.model.model_adapter": _fc_adapter}) from src.client.agents.litellm_agent import LiteLLMAgent +from src.typings import AgentContextLimitException def _mock_resp(content="Hello!"): @@ -34,10 +35,23 @@ def _run_inference(agent, history, response): return result, fake.completion +def _run_inference_error(agent, history, error): + fake = types.ModuleType("litellm") + fake.completion = MagicMock(side_effect=error) + with mock.patch.dict(sys.modules, {"litellm": fake}): + with mock.patch("src.client.agents.litellm_agent.time.sleep"): + try: + agent.inference(history) + return None, fake.completion + except Exception as e: + return e, fake.completion + + class TestInit: def test_default(self): a = LiteLLMAgent(model="openai/gpt-4o") assert a.model == "openai/gpt-4o" + assert a.model_name == "openai/gpt-4o" assert a.api_key is None assert a.api_args == {} @@ -102,6 +116,55 @@ def test_model_forwarded(self): _, comp = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp()) assert comp.call_args.kwargs["model"] == "anthropic/claude-sonnet-4-20250514" + def test_none_content_returns_empty_string(self): + a = LiteLLMAgent(model="openai/gpt-4o") + result, _ = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp(content=None)) + assert result == "" + + def test_empty_string_content(self): + a = LiteLLMAgent(model="openai/gpt-4o") + result, _ = _run_inference(a, [{"role": "user", "content": "hi"}], _mock_resp(content="")) + assert result == "" + + +class TestRetryAndErrors: + def test_retries_3_times_then_raises(self): + a = LiteLLMAgent(model="openai/gpt-4o") + exc, comp = _run_inference_error( + a, [{"role": "user", "content": "hi"}], RuntimeError("transient") + ) + assert comp.call_count == 3 + assert "Failed after 3 attempts" in str(exc) + + def test_context_limit_raises_immediately(self): + a = LiteLLMAgent(model="openai/gpt-4o") + exc, comp = _run_inference_error( + a, [{"role": "user", "content": "hi"}], + RuntimeError("prompt is too long, exceeds context token limit"), + ) + assert isinstance(exc, AgentContextLimitException) + assert comp.call_count == 1 + + def test_auth_error_not_retried(self): + auth_exc = type("AuthenticationError", (Exception,), {})("bad key") + auth_exc.__class__.__module__ = "litellm.exceptions" + auth_exc.__class__.__qualname__ = "AuthenticationError" + a = LiteLLMAgent(model="openai/gpt-4o") + exc, comp = _run_inference_error( + a, [{"role": "user", "content": "hi"}], auth_exc, + ) + assert comp.call_count == 1 + assert "bad key" in str(exc) + + def test_rate_limit_not_misclassified_as_context_limit(self): + a = LiteLLMAgent(model="openai/gpt-4o") + exc, comp = _run_inference_error( + a, [{"role": "user", "content": "hi"}], + RuntimeError("Rate limit exceeded, please retry"), + ) + assert not isinstance(exc, AgentContextLimitException) + assert comp.call_count == 3 + class TestRegistration: def test_importable(self): @@ -121,3 +184,7 @@ def test_instance_factory(self): agent = factory.create() assert isinstance(agent, LiteLLMAgent) assert agent.model == "openai/gpt-4o" + + def test_model_name_attribute(self): + a = LiteLLMAgent(model="openai/gpt-4o") + assert a.model_name == "openai/gpt-4o"