-
Notifications
You must be signed in to change notification settings - Fork 23
Add user-specific rate limits with admin API management #440
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -11,3 +11,4 @@ reference-kernels/ | |
| yoyo.ini | ||
| .venv | ||
| .claude/ | ||
| *.egg-info/ | ||
| Original file line number | Diff line number | Diff line change | ||||||||||||||||
|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
|
|
@@ -388,6 +388,13 @@ async def run_submission( # noqa: C901 | |||||||||||||||||
| StreamingResponse: A streaming response containing the status and results of the submission. | ||||||||||||||||||
| """ | ||||||||||||||||||
| await simple_rate_limit() | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check user-specific rate limits | ||||||||||||||||||
| with db_context as db: | ||||||||||||||||||
| rate_check = db.check_user_submission_rate(user_info["user_id"]) | ||||||||||||||||||
| if not rate_check["allowed"]: | ||||||||||||||||||
| raise HTTPException(status_code=429, detail=f"Rate limit exceeded. {rate_check['retry_after']}") | ||||||||||||||||||
|
|
||||||||||||||||||
| submission_request, submission_mode_enum = await to_submit_info( | ||||||||||||||||||
| user_info, submission_mode, file, leaderboard_name, gpu_type, db_context | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
@@ -449,6 +456,11 @@ async def run_submission_async( | |||||||||||||||||
| await simple_rate_limit() | ||||||||||||||||||
| logger.info(f"Received submission request for {leaderboard_name} {gpu_type} {submission_mode}") | ||||||||||||||||||
|
|
||||||||||||||||||
| # Check user-specific rate limits | ||||||||||||||||||
| with db_context as db: | ||||||||||||||||||
| rate_check = db.check_user_submission_rate(user_info["user_id"]) | ||||||||||||||||||
| if not rate_check["allowed"]: | ||||||||||||||||||
| raise HTTPException(status_code=429, detail=f"Rate limit exceeded. {rate_check['retry_after']}") | ||||||||||||||||||
|
|
||||||||||||||||||
| # throw error if submission request is invalid | ||||||||||||||||||
| try: | ||||||||||||||||||
|
|
@@ -643,6 +655,85 @@ async def admin_update_problems( | |||||||||||||||||
| } | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @app.get("/admin/rate-limits") | ||||||||||||||||||
| async def get_all_rate_limits( | ||||||||||||||||||
| _: Annotated[None, Depends(require_admin)], | ||||||||||||||||||
| db_context=Depends(get_db), | ||||||||||||||||||
| ) -> dict: | ||||||||||||||||||
| """Get all user rate limit overrides.""" | ||||||||||||||||||
| with db_context as db: | ||||||||||||||||||
| rate_limits = db.get_all_user_rate_limits() | ||||||||||||||||||
| return {"status": "ok", "rate_limits": rate_limits} | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @app.get("/admin/rate-limits/{user_id}") | ||||||||||||||||||
| async def get_user_rate_limit( | ||||||||||||||||||
| user_id: str, | ||||||||||||||||||
| _: Annotated[None, Depends(require_admin)], | ||||||||||||||||||
| db_context=Depends(get_db), | ||||||||||||||||||
| ) -> dict: | ||||||||||||||||||
| """Get rate limit for a specific user.""" | ||||||||||||||||||
| with db_context as db: | ||||||||||||||||||
| rate_limit = db.get_user_rate_limit(user_id) | ||||||||||||||||||
| if rate_limit is None: | ||||||||||||||||||
| raise HTTPException(status_code=404, detail="No rate limit override found for this user") | ||||||||||||||||||
| return {"status": "ok", "rate_limit": rate_limit} | ||||||||||||||||||
|
|
||||||||||||||||||
|
|
||||||||||||||||||
| @app.put("/admin/rate-limits/{user_id}") | ||||||||||||||||||
| async def set_user_rate_limit( | ||||||||||||||||||
| user_id: str, | ||||||||||||||||||
| payload: dict, | ||||||||||||||||||
| _: Annotated[None, Depends(require_admin)], | ||||||||||||||||||
| db_context=Depends(get_db), | ||||||||||||||||||
| ) -> dict: | ||||||||||||||||||
|
Comment on lines
+684
to
+689
|
||||||||||||||||||
| """Set or update rate limit for a user. | ||||||||||||||||||
|
|
||||||||||||||||||
| Payload fields: | ||||||||||||||||||
| max_submissions_per_hour (int, optional): Max submissions per hour | ||||||||||||||||||
| max_submissions_per_day (int, optional): Max submissions per day | ||||||||||||||||||
| note (str, optional): Admin note about why the limit was set | ||||||||||||||||||
| """ | ||||||||||||||||||
| max_per_hour = payload.get("max_submissions_per_hour") | ||||||||||||||||||
| max_per_day = payload.get("max_submissions_per_day") | ||||||||||||||||||
| note = payload.get("note") | ||||||||||||||||||
|
|
||||||||||||||||||
| if max_per_hour is None and max_per_day is None: | ||||||||||||||||||
| raise HTTPException( | ||||||||||||||||||
| status_code=400, | ||||||||||||||||||
| detail="At least one of max_submissions_per_hour or max_submissions_per_day is required", | ||||||||||||||||||
| ) | ||||||||||||||||||
|
|
||||||||||||||||||
| if max_per_hour is not None and (not isinstance(max_per_hour, int) or max_per_hour < 0): | ||||||||||||||||||
| raise HTTPException(status_code=400, detail="max_submissions_per_hour must be a non-negative integer") | ||||||||||||||||||
|
|
||||||||||||||||||
| if max_per_day is not None and (not isinstance(max_per_day, int) or max_per_day < 0): | ||||||||||||||||||
|
Comment on lines
+707
to
+710
|
||||||||||||||||||
| if max_per_hour is not None and (not isinstance(max_per_hour, int) or max_per_hour < 0): | |
| raise HTTPException(status_code=400, detail="max_submissions_per_hour must be a non-negative integer") | |
| if max_per_day is not None and (not isinstance(max_per_day, int) or max_per_day < 0): | |
| if max_per_hour is not None and (type(max_per_hour) is not int or max_per_hour < 0): | |
| raise HTTPException(status_code=400, detail="max_submissions_per_hour must be a non-negative integer") | |
| if max_per_day is not None and (type(max_per_day) is not int or max_per_day < 0): |
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -1172,6 +1172,218 @@ def cleanup_temp_users(self): | |
| logger.exception("Could not cleanup temp users", exc_info=e) | ||
| raise KernelBotError("Database error while cleaning up temp users") from e | ||
|
|
||
| def get_user_rate_limit(self, user_id: str) -> Optional[dict]: | ||
| """ | ||
| Get the rate limit override for a specific user. | ||
|
|
||
| Returns: | ||
| Optional[dict]: Rate limit info or None if no override exists. | ||
| """ | ||
| try: | ||
| self.cursor.execute( | ||
| """ | ||
| SELECT user_id, max_submissions_per_hour, max_submissions_per_day, | ||
| note, created_at, updated_at | ||
| FROM leaderboard.user_rate_limits | ||
| WHERE user_id = %s | ||
| """, | ||
| (user_id,), | ||
| ) | ||
| row = self.cursor.fetchone() | ||
| if row is None: | ||
| return None | ||
| return { | ||
| "user_id": row[0], | ||
| "max_submissions_per_hour": row[1], | ||
| "max_submissions_per_day": row[2], | ||
| "note": row[3], | ||
| "created_at": row[4], | ||
| "updated_at": row[5], | ||
| } | ||
| except psycopg2.Error as e: | ||
| self.connection.rollback() | ||
| logger.exception("Error fetching rate limit for user %s", user_id, exc_info=e) | ||
| raise KernelBotError("Error fetching user rate limit") from e | ||
|
|
||
| def get_all_user_rate_limits(self) -> list[dict]: | ||
| """ | ||
| Get all user rate limit overrides. | ||
|
|
||
| Returns: | ||
| list[dict]: All user rate limit entries. | ||
| """ | ||
| try: | ||
| self.cursor.execute( | ||
| """ | ||
| SELECT rl.user_id, rl.max_submissions_per_hour, rl.max_submissions_per_day, | ||
| rl.note, rl.created_at, rl.updated_at, ui.user_name | ||
| FROM leaderboard.user_rate_limits rl | ||
| LEFT JOIN leaderboard.user_info ui ON rl.user_id = ui.id | ||
| ORDER BY rl.updated_at DESC | ||
| """ | ||
| ) | ||
| return [ | ||
| { | ||
| "user_id": row[0], | ||
| "max_submissions_per_hour": row[1], | ||
| "max_submissions_per_day": row[2], | ||
| "note": row[3], | ||
| "created_at": row[4], | ||
| "updated_at": row[5], | ||
| "user_name": row[6], | ||
| } | ||
| for row in self.cursor.fetchall() | ||
| ] | ||
| except psycopg2.Error as e: | ||
| self.connection.rollback() | ||
| logger.exception("Error fetching all user rate limits", exc_info=e) | ||
| raise KernelBotError("Error fetching user rate limits") from e | ||
|
|
||
| def set_user_rate_limit( | ||
| self, | ||
| user_id: str, | ||
| max_submissions_per_hour: Optional[int] = None, | ||
| max_submissions_per_day: Optional[int] = None, | ||
| note: Optional[str] = None, | ||
| ) -> dict: | ||
| """ | ||
| Set or update rate limit for a user (upsert). | ||
|
|
||
| Returns: | ||
| dict: The created/updated rate limit entry. | ||
| """ | ||
| try: | ||
| self.cursor.execute( | ||
| """ | ||
| INSERT INTO leaderboard.user_rate_limits | ||
| (user_id, max_submissions_per_hour, max_submissions_per_day, note) | ||
| VALUES (%s, %s, %s, %s) | ||
| ON CONFLICT (user_id) DO UPDATE SET | ||
| max_submissions_per_hour = EXCLUDED.max_submissions_per_hour, | ||
| max_submissions_per_day = EXCLUDED.max_submissions_per_day, | ||
| note = EXCLUDED.note, | ||
| updated_at = NOW() | ||
| RETURNING user_id, max_submissions_per_hour, max_submissions_per_day, | ||
| note, created_at, updated_at | ||
| """, | ||
| (user_id, max_submissions_per_hour, max_submissions_per_day, note), | ||
| ) | ||
|
Comment on lines
+1256
to
+1270
|
||
| row = self.cursor.fetchone() | ||
| self.connection.commit() | ||
| return { | ||
| "user_id": row[0], | ||
| "max_submissions_per_hour": row[1], | ||
| "max_submissions_per_day": row[2], | ||
| "note": row[3], | ||
| "created_at": row[4], | ||
| "updated_at": row[5], | ||
| } | ||
| except psycopg2.Error as e: | ||
| self.connection.rollback() | ||
| logger.exception("Error setting rate limit for user %s", user_id, exc_info=e) | ||
| raise KernelBotError("Error setting user rate limit") from e | ||
|
|
||
| def delete_user_rate_limit(self, user_id: str) -> bool: | ||
| """ | ||
| Delete a user's rate limit override. | ||
|
|
||
| Returns: | ||
| bool: True if a row was deleted, False if no override existed. | ||
| """ | ||
| try: | ||
| self.cursor.execute( | ||
| """ | ||
| DELETE FROM leaderboard.user_rate_limits | ||
| WHERE user_id = %s | ||
| """, | ||
| (user_id,), | ||
| ) | ||
| deleted = self.cursor.rowcount > 0 | ||
| self.connection.commit() | ||
| return deleted | ||
| except psycopg2.Error as e: | ||
| self.connection.rollback() | ||
| logger.exception("Error deleting rate limit for user %s", user_id, exc_info=e) | ||
| raise KernelBotError("Error deleting user rate limit") from e | ||
|
|
||
| def check_user_submission_rate(self, user_id: str) -> dict: | ||
| """ | ||
| Check a user's current submission counts against their rate limits. | ||
|
|
||
| Returns: | ||
| dict with keys: | ||
| - allowed: bool, whether the user can submit | ||
| - hourly_count: int, submissions in the last hour | ||
| - daily_count: int, submissions in the last day | ||
| - hourly_limit: int or None | ||
| - daily_limit: int or None | ||
| - retry_after: str or None, human-readable wait time if blocked | ||
| """ | ||
| try: | ||
| # Get user's rate limits (None means no override) | ||
| rate_limit = self.get_user_rate_limit(user_id) | ||
| if rate_limit is None: | ||
| return { | ||
| "allowed": True, | ||
| "hourly_count": 0, | ||
| "daily_count": 0, | ||
| "hourly_limit": None, | ||
| "daily_limit": None, | ||
| "retry_after": None, | ||
| } | ||
|
|
||
| hourly_limit = rate_limit["max_submissions_per_hour"] | ||
| daily_limit = rate_limit["max_submissions_per_day"] | ||
|
|
||
| # If both limits are None, user is unrestricted | ||
| if hourly_limit is None and daily_limit is None: | ||
| return { | ||
| "allowed": True, | ||
| "hourly_count": 0, | ||
| "daily_count": 0, | ||
| "hourly_limit": None, | ||
| "daily_limit": None, | ||
| "retry_after": None, | ||
| } | ||
|
|
||
| # Count submissions in the last hour and day | ||
| self.cursor.execute( | ||
| """ | ||
| SELECT | ||
| COUNT(*) FILTER (WHERE submission_time > NOW() - INTERVAL '1 hour') AS hourly_count, | ||
| COUNT(*) FILTER (WHERE submission_time > NOW() - INTERVAL '1 day') AS daily_count | ||
| FROM leaderboard.submission | ||
| WHERE user_id = %s | ||
| """, | ||
| (user_id,), | ||
| ) | ||
|
Comment on lines
+1350
to
+1359
|
||
| row = self.cursor.fetchone() | ||
| hourly_count = row[0] | ||
| daily_count = row[1] | ||
|
|
||
| # Check limits | ||
| hourly_exceeded = hourly_limit is not None and hourly_count >= hourly_limit | ||
| daily_exceeded = daily_limit is not None and daily_count >= daily_limit | ||
|
|
||
| retry_after = None | ||
| if hourly_exceeded: | ||
| retry_after = "Try again in up to 1 hour" | ||
| if daily_exceeded: | ||
| retry_after = "Try again in up to 24 hours" | ||
|
|
||
| return { | ||
| "allowed": not (hourly_exceeded or daily_exceeded), | ||
| "hourly_count": hourly_count, | ||
| "daily_count": daily_count, | ||
| "hourly_limit": hourly_limit, | ||
| "daily_limit": daily_limit, | ||
| "retry_after": retry_after, | ||
| } | ||
| except psycopg2.Error as e: | ||
| self.connection.rollback() | ||
| logger.exception("Error checking rate limit for user %s", user_id, exc_info=e) | ||
| raise KernelBotError("Error checking user rate limit") from e | ||
|
|
||
| def validate_cli_id(self, cli_id: str) -> Optional[dict[str, str]]: | ||
| """ | ||
| Validates a CLI ID and returns the associated user ID if valid. | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,27 @@ | ||
| """ | ||
| add-user-rate-limits | ||
| """ | ||
|
|
||
| from yoyo import step | ||
|
|
||
| __depends__ = {'20260108_01_gzSm3-add-submission-status'} | ||
|
|
||
| steps = [ | ||
| step( | ||
| # forward | ||
| """ | ||
| CREATE TABLE leaderboard.user_rate_limits ( | ||
| user_id TEXT PRIMARY KEY REFERENCES leaderboard.user_info(id), | ||
| max_submissions_per_hour INTEGER, | ||
| max_submissions_per_day INTEGER, | ||
| note TEXT, | ||
| created_at TIMESTAMP WITH TIME ZONE DEFAULT NOW(), | ||
| updated_at TIMESTAMP WITH TIME ZONE DEFAULT NOW() | ||
| ); | ||
|
Comment on lines
+13
to
+20
|
||
| """, | ||
| # backward | ||
| """ | ||
| DROP TABLE leaderboard.user_rate_limits; | ||
| """ | ||
| ) | ||
| ] | ||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For 429 responses, it’s best practice to include a
Retry-Afterheader (seconds or HTTP date). Sinceretry_afteris currently a human string, consider also returning a machine-readable duration and settingHTTPException(..., headers={'Retry-After': ...}).