-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathFastAPI_client.py
More file actions
317 lines (271 loc) · 11.5 KB
/
FastAPI_client.py
File metadata and controls
317 lines (271 loc) · 11.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
# FastAPI_client.py — Robust FastAPI client with retries/backoff + rich error handling
# Usage:
#   from FastAPI_client import FastAPIClient, AuthError, RateLimitError, FastAPIHTTPError, FastAPIConnectionError
#
#   client = FastAPIClient(base_url="http://127.0.0.1:8799", token=None)
#   try:
#       res = client.search_docs("market growth", k=5)
#   except AuthError as e:
#       print("Auth issue:", e)
#   except RateLimitError as e:
#       print("Rate limited, retry after:", e.retry_after)
#   except FastAPIHTTPError as e:
#       print("HTTP error:", e.status, e.message)
#   except FastAPIConnectionError as e:
#       print("Network error:", e)
#   else:
#       print(res)
from __future__ import annotations
import os
import time
import random
import json
from typing import Any, Dict, Optional
import requests
# =========================
# Custom exception types
# =========================
class FastAPIClientError(Exception):
    """Base error for the FastAPI client."""


class FastAPIConnectionError(FastAPIClientError):
    """Network/connection problems (DNS, timeouts, etc.)."""


class FastAPIHTTPError(FastAPIClientError):
    """HTTP error with status and optional payload.

    Attributes:
        status: HTTP status code of the response.
        message: Short human-readable error message.
        payload: Parsed error body ({} when none was given).
        headers: Response headers ({} when none were given).
    """

    def __init__(
        self,
        status: int,
        message: str = "",
        payload: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        super().__init__(f"HTTP {status}: {message}")
        self.status = status
        self.message = message
        self.payload = payload or {}
        self.headers = headers or {}


class AuthError(FastAPIHTTPError):
    """401/403 authorization/authentication errors."""


class RateLimitError(FastAPIHTTPError):
    """429 errors, optionally including the server's Retry-After delay.

    ``retry_after`` is the requested wait in seconds, or None when the header
    is absent or unusable. Both Retry-After forms are accepted: delay-seconds
    (e.g. "120") and HTTP-date (e.g. "Wed, 21 Oct 2015 07:28:00 GMT"); a date
    already in the past yields 0.0. Negative, NaN and infinite values are
    rejected (None).
    """

    def __init__(
        self,
        status: int,
        message: str = "",
        payload: Optional[Dict[str, Any]] = None,
        headers: Optional[Dict[str, str]] = None,
    ):
        super().__init__(status, message, payload, headers)
        self.retry_after: Optional[float] = self._parse_retry_after(
            self._find_header(headers, "Retry-After")
        )

    @staticmethod
    def _find_header(headers: Optional[Dict[str, str]], name: str) -> Optional[str]:
        """Return the value of header ``name``, matching the name case-insensitively.

        FastAPIClient passes ``dict(resp.headers)``, which drops requests'
        case-insensitive lookup, and servers may send names in any case
        (e.g. "retry-after"), so an exact-case ``.get`` can miss the header.
        """
        if not headers:
            return None
        wanted = name.lower()
        for key, value in headers.items():
            if key.lower() == wanted:
                return value
        return None

    @staticmethod
    def _parse_retry_after(raw: Optional[str]) -> Optional[float]:
        """Convert a Retry-After value to non-negative seconds, or None if unusable."""
        if not raw:
            return None
        raw = raw.strip()
        try:
            seconds = float(raw)
        except ValueError:
            pass
        else:
            # ``0 <= x < inf`` is False for negatives, NaN and infinity.
            return seconds if 0 <= seconds < float("inf") else None
        # Not a number: try the HTTP-date form.
        from datetime import datetime, timezone
        from email.utils import parsedate_to_datetime

        try:
            when = parsedate_to_datetime(raw)
        except (TypeError, ValueError, IndexError):
            # Python < 3.10 raises TypeError on garbage; newer versions ValueError.
            return None
        if when.tzinfo is None:
            # "-0000" offsets parse as naive datetimes; HTTP dates are always GMT.
            when = when.replace(tzinfo=timezone.utc)
        return max(0.0, (when - datetime.now(timezone.utc)).total_seconds())


# ---------------------------------------
# Backward-compatibility alias (optional)
# ---------------------------------------
# If other modules still import the old names, keep these aliases:
MCPHostError = FastAPIClientError
MCPConnectionError = FastAPIConnectionError
MCPHTTPError = FastAPIHTTPError
# =========================
# FastAPI client
# =========================
# Server URL resolution order: FASTAPI_SERVER_URL, then the legacy
# MCP_SERVER_URL, then a local default. An empty variable counts as unset.
DEFAULT_SERVER = next(
    (url for url in map(os.getenv, ("FASTAPI_SERVER_URL", "MCP_SERVER_URL")) if url),
    "http://127.0.0.1:8799",
)
class FastAPIClient:
    """
    Minimal FastAPI client that calls a server exposing /tools endpoints.

    Features:
    - Retries connection errors, 5xx and 429 with exponential backoff + jitter
      (on 429, a usable Retry-After replaces the computed delay).
    - Raises structured exceptions for 4xx/5xx/connection issues.
    - Optional Bearer token for auth.
    - Context-manager support: ``with FastAPIClient(...) as client:`` closes
      the HTTP session on exit when the client created it.

    NOTE: POST tool calls are retried on 5xx and connection errors. If an
    endpoint is not idempotent, a retry may repeat its side effect.
    """

    def __init__(
        self,
        base_url: str = DEFAULT_SERVER,
        token: Optional[str] = None,
        timeout: float = 60.0,
        max_attempts: int = 3,
        base_delay: float = 0.5,
        backoff_factor: float = 2.0,
        max_delay: float = 8.0,
        session: Optional[requests.Session] = None,
        default_headers: Optional[Dict[str, str]] = None,
    ):
        """
        Args:
            base_url: Server root; a trailing "/" is stripped.
            token: Optional Bearer token sent in the Authorization header.
            timeout: Per-request timeout in seconds, passed to requests.
            max_attempts: Total attempts per POST (clamped to at least 1).
            base_delay: First backoff delay in seconds.
            backoff_factor: Multiplier applied to the delay on each attempt.
            max_delay: Upper bound on the computed backoff delay (before jitter).
            session: Optional pre-configured requests.Session. A session
                supplied by the caller is never closed by ``close()``.
            default_headers: Extra headers merged over the JSON defaults.
        """
        self.base_url = base_url.rstrip("/")
        self.token = token
        self.timeout = timeout
        self.max_attempts = max(1, max_attempts)
        self.base_delay = base_delay
        self.backoff_factor = backoff_factor
        self.max_delay = max_delay
        # Track ownership so close() never shuts down a session the caller shares.
        self._owns_session = session is None
        self.session = session if session is not None else requests.Session()
        # Always ask for/declare JSON
        self.default_headers = {"Content-Type": "application/json", "Accept": "application/json"}
        if default_headers:
            self.default_headers.update(default_headers)

    # ------------- lifecycle -------------
    def close(self) -> None:
        """Close the underlying HTTP session if this client created it."""
        if getattr(self, "_owns_session", False):
            self.session.close()

    def __enter__(self) -> "FastAPIClient":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    # ------------- convenience tool methods -------------
    def ingest_pdf(self, path: str, doc_type: str = "pdf") -> Dict[str, Any]:
        """Ingest a local PDF file into the RAG index."""
        payload = {"path": path, "doc_type": doc_type}
        return self._post("/tools/ingest_pdf", payload)

    def ingest_csv(self, path: str, text_col: Optional[str] = None, doc_type: str = "csv") -> Dict[str, Any]:
        """Ingest a local CSV file into the RAG index (optionally specify a text column)."""
        payload: Dict[str, Any] = {"path": path, "doc_type": doc_type}
        if text_col:
            payload["text_col"] = text_col
        return self._post("/tools/ingest_csv", payload)

    def search_docs(self, query: str, k: int = 5) -> Dict[str, Any]:
        """Search the RAG index and return top-k hits with source metadata."""
        return self._post("/tools/search_docs", {"query": query, "k": int(k)})

    def health(self) -> Dict[str, Any]:
        """
        GET the server root (/) once, without retries.

        Returns:
            The parsed JSON body on 2xx, or ``{"ok": True, "message": ...}``
            when the 2xx body is not JSON.

        Raises:
            FastAPIConnectionError: the request itself failed.
            FastAPIHTTPError: the server answered with a non-2xx status.
        """
        try:
            resp = self.session.get(f"{self.base_url}/", headers=self._headers(), timeout=self.timeout)
        except requests.RequestException as exc:
            raise FastAPIConnectionError(f"Request failed: {exc}") from exc
        if 200 <= resp.status_code < 300:
            try:
                return resp.json()
            except ValueError:
                return {"ok": True, "message": "Server responded but not JSON."}
        # Surface non-2xx as HTTP errors
        raise self._error_from_response(resp, FastAPIHTTPError)

    # ------------- core request logic -------------
    def _headers(self, extra: Optional[Dict[str, str]] = None) -> Dict[str, str]:
        """Build request headers with optional Bearer token."""
        headers = dict(self.default_headers)
        if self.token:
            headers["Authorization"] = f"Bearer {self.token}"
        if extra:
            headers.update(extra)
        return headers

    def _sleep_with_backoff(self, attempt: int, retry_after_seconds: Optional[float] = None) -> None:
        """
        Sleep after failed attempt number ``attempt`` (1-based).

        Uses the server-provided Retry-After when it is a finite, non-negative
        number; otherwise exponential backoff capped at ``max_delay``. Up to 20%
        random jitter is added either way so clients don't retry in lockstep.
        """
        # ``0 <= x < inf`` is False for negatives, NaN and infinity, all of which
        # would make time.sleep() raise instead of retrying.
        if retry_after_seconds is not None and 0 <= retry_after_seconds < float("inf"):
            delay = retry_after_seconds
        else:
            delay = min(self.base_delay * (self.backoff_factor ** (attempt - 1)), self.max_delay)
        delay = max(0.0, delay)  # a negative base_delay must not reach time.sleep()
        jitter = random.uniform(0, delay * 0.2)  # up to 20% jitter
        time.sleep(delay + jitter)

    @staticmethod
    def _error_from_response(resp: requests.Response, error_cls: type) -> FastAPIHTTPError:
        """Build an ``error_cls`` (a FastAPIHTTPError subclass) from an error response."""
        payload = _safe_json(resp)
        return error_cls(
            status=resp.status_code,
            message=_message_from_payload(payload),
            payload=payload,
            headers=dict(resp.headers),
        )

    def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
        """
        POST ``payload`` as JSON to ``path`` with retries; return the parsed JSON.

        Raises:
            FastAPIConnectionError: network failure on the last attempt.
            RateLimitError: 429 on the last attempt.
            AuthError: 401/403 (never retried).
            FastAPIHTTPError: other non-2xx status, or a 2xx body that is not JSON.
        """
        url = f"{self.base_url}{path}"
        # Serialise once: the body is identical for every attempt, and a
        # non-serialisable payload raises TypeError before any request is sent.
        body = json.dumps(payload)
        for attempt in range(1, self.max_attempts + 1):
            retries_left = attempt < self.max_attempts
            try:
                resp = self.session.post(
                    url,
                    headers=self._headers(),
                    data=body,
                    timeout=self.timeout,
                )
            except requests.RequestException as exc:
                # Network-level issue: retry (if attempts left), else raise
                if retries_left:
                    self._sleep_with_backoff(attempt)
                    continue
                raise FastAPIConnectionError(f"Request failed: {exc}") from exc

            status = resp.status_code

            # 2xx → success: parse JSON strictly
            if 200 <= status < 300:
                try:
                    return resp.json()
                except ValueError as exc:
                    # Server didn't return JSON as expected
                    raise FastAPIHTTPError(
                        status,
                        "Invalid JSON in response",
                        headers=dict(resp.headers),
                    ) from exc

            # 429 → rate limit: honor Retry-After if usable, then retry/raise
            if status == 429:
                if retries_left:
                    self._sleep_with_backoff(attempt, retry_after_seconds=_retry_after_seconds(resp))
                    continue
                raise self._error_from_response(resp, RateLimitError)

            # 5xx → server error: retry with backoff, then raise
            if 500 <= status < 600:
                if retries_left:
                    self._sleep_with_backoff(attempt)
                    continue
                raise self._error_from_response(resp, FastAPIHTTPError)

            # 401/403 → auth/scope problems: do not retry, raise immediately
            if status in (401, 403):
                raise self._error_from_response(resp, AuthError)

            # Other 4xx (caller can fix params) or any unexpected status: raise
            raise self._error_from_response(resp, FastAPIHTTPError)

        # We should never reach here; loop either returns or raises.
        raise FastAPIClientError("Unexpected control flow in _post()")
# =========================
# Helpers (module-level)
# =========================
def _safe_json(resp: requests.Response) -> Dict[str, Any]:
    """
    Parse an error response body into a dict, never raising.

    Returns the decoded JSON when it is an object. Otherwise (body is not JSON,
    or is JSON but a list/string/number/null) returns ``{"message": ...}``
    built from the stripped raw text, falling back to the HTTP reason phrase
    and finally "Unknown error". Callers can therefore always do key lookups.
    """
    try:
        data = resp.json()
    except ValueError:
        pass
    else:
        if isinstance(data, dict):
            return data
    text = (resp.text or "").strip()
    return {"message": text or resp.reason or "Unknown error"}
def _retry_after_seconds(resp: requests.Response) -> Optional[float]:
    """
    Extract the Retry-After header as non-negative seconds.

    Accepts both forms from RFC 9110: delay-seconds ("120") and HTTP-date
    ("Wed, 21 Oct 2015 07:28:00 GMT"); a date in the past yields 0.0.
    Returns None when the header is absent, unparseable, negative, NaN or
    infinite, so the caller falls back to its own backoff.
    """
    raw = resp.headers.get("Retry-After")
    if not raw:
        return None
    raw = raw.strip()
    try:
        seconds = float(raw)
    except ValueError:
        pass
    else:
        # ``0 <= x < inf`` is False for negatives, NaN and infinity.
        return seconds if 0 <= seconds < float("inf") else None
    # Not a number: try the HTTP-date form.
    from datetime import datetime, timezone
    from email.utils import parsedate_to_datetime

    try:
        when = parsedate_to_datetime(raw)
    except (TypeError, ValueError, IndexError):
        # Python < 3.10 raises TypeError on garbage; newer versions ValueError.
        return None
    if when.tzinfo is None:
        # "-0000" offsets parse as naive datetimes; HTTP dates are always GMT.
        when = when.replace(tzinfo=timezone.utc)
    return max(0.0, (when - datetime.now(timezone.utc)).total_seconds())
def _message_from_payload(payload: Dict[str, Any]) -> str:
    """
    Pick a concise error message from a typical API payload.

    Checks "detail", "message", "error", "errors" in that order and uses the
    first one that is present and non-empty (None / "" / [] / {} are skipped).
    List values are flattened with ``_format_error_value``, so FastAPI's 422
    ``detail`` list renders as "body.field: msg; ...". A non-dict payload is
    stringified. Returns "Request failed" when nothing usable is found.
    """
    if not isinstance(payload, dict):
        return str(payload) if payload else "Request failed"
    for key in ("detail", "message", "error", "errors"):
        value = payload.get(key)
        if value is None or (isinstance(value, (str, list, tuple, dict)) and not value):
            continue
        return _format_error_value(value)
    return "Request failed"


def _format_error_value(value: Any) -> str:
    """Render one error field as text, flattening lists (incl. FastAPI validation items)."""
    if not isinstance(value, (list, tuple)):
        return str(value)
    parts = []
    for item in value:
        if isinstance(item, dict) and "msg" in item:
            # FastAPI/Pydantic validation item: {"loc": [...], "msg": "...", "type": "..."}
            loc = item.get("loc")
            where = ".".join(str(p) for p in loc) if isinstance(loc, (list, tuple)) else ""
            parts.append(f"{where}: {item['msg']}" if where else str(item["msg"]))
        else:
            parts.append(str(item))
    return "; ".join(parts)