Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions api/common/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ class Settings(BaseSettings):
signed_url_expiry_seconds: int = 3600 * 1 # 1 hour
task_backlog_limit: int = 100 # Max number of waiting tasks allowed before rejecting new ones
enable_mcp: bool = True
result_expires_days: int = 30 # Number of days to keep task results

@property
def encoded_storage_key(self) -> bytes:
Expand Down
66 changes: 56 additions & 10 deletions api/common/task_helpers.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from datetime import datetime, timezone
from typing import Any, Dict
from typing import Any, Dict, Optional
from uuid import UUID

import httpx
Expand All @@ -10,11 +10,12 @@
from common.config import settings
from common.logger import logger
from common.redis_manager import redis_manager
from common.schemas import DeleteResponse, TaskStatus
from common.schemas import DeleteResponse, Identity, TaskStatus
from worker import celery_app


@cached(cache=TTLCache(maxsize=128, ttl=5))
def get_task_info(task_id: str) -> Dict[str, Any]:
def _get_task_info(task_id: str) -> Dict[str, Any]:
"""
Fetch task information from Flower API.
So we can use this to provide more detailed task status in the API responses.
Expand Down Expand Up @@ -54,18 +55,63 @@ def get_task_info(task_id: str) -> Dict[str, Any]:
return result


def create_task(task_name: str, task_queue: str, payload: dict, identity: Identity) -> AsyncResult:
    """
    Unified helper to enqueue a task in Celery.

    Args:
        task_name: Registered Celery task name to dispatch.
        task_queue: Name of the queue the task is routed to.
        payload: Request payload, passed as the task's single positional argument.
        identity: Caller identity; its fields are forwarded as task kwargs.

    Returns:
        The AsyncResult handle for the enqueued task.

    Raises:
        HTTPException: 500 when the broker rejects the task or sending fails.
    """
    try:
        return celery_app.send_task(
            task_name,
            queue=task_queue,
            args=[payload],
            kwargs=identity.model_dump(),
        )
    except Exception as e:
        # Surface broker/connection failures to the API caller as a 500,
        # keeping the original exception chained for server-side debugging.
        raise HTTPException(status_code=500, detail=f"Error creating task: {str(e)}") from e


def get_task_detailed(id: UUID) -> tuple[AsyncResult, dict, list[str]]:
    """
    Fetches the task across current Redis storage (Broker and Result Backend).

    Args:
        id: The task's UUID.

    Returns:
        (AsyncResult, task_info, initial_logs) where task_info is Flower
        metadata and initial_logs is either the queue-position line (for
        waiting tasks) or the worker-reported logs (when available).

    Raises:
        HTTPException: 404 if the task is not in Redis (either never existed
            or has expired).
    """

    def get_queue_position(task_id: str) -> Optional[str]:
        """Inner helper: check the broker and format the queue position log line."""
        pos_data = redis_manager.get_queue_position(task_id)
        if pos_data:
            return f"Queue {pos_data.queue} position: {pos_data.position} / {pos_data.total}"
        return None

    result = AsyncResult(str(id), app=celery_app)
    logs: list[str] = []

    # Celery reports waiting tasks as PENDING and also unknown tasks as PENDING,
    # so the broker's queue position is used to tell the two cases apart.
    if result.status == TaskStatus.PENDING:
        queue_position = get_queue_position(str(id))
        if queue_position is None:
            # Truly not found: never existed, or the result already expired.
            raise HTTPException(status_code=404, detail="Task not found or has expired")

        # Keep the queue position log line to return to the user.
        logs = [queue_position]
    elif isinstance(result.info, dict):
        # Running/finished: expose the worker-reported logs when available.
        # (isinstance is False for None, so this also covers empty result.info.)
        logs = result.info.get("logs", [])

    # Enrich with Flower metadata if available (metrics, worker info, etc).
    task_info = _get_task_info(str(id))

    return result, task_info, logs


def cancel_task(id: UUID, celery_app) -> DeleteResponse:
def cancel_task(id: UUID) -> DeleteResponse:
result = AsyncResult(str(id), app=celery_app)

if result.status in ["SUCCESS", "FAILURE", "REVOKED"]:
Expand Down
46 changes: 13 additions & 33 deletions api/images/router.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from uuid import UUID

from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends

from common.auth import verify_token
from common.schemas import DeleteResponse, Identity
from common.storage import signed_url_for_file
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from common.task_helpers import cancel_task, create_task, get_task_detailed
from images.schemas import (
MODEL_META,
ImageCreateResponse,
Expand All @@ -16,30 +15,21 @@
ImageWorkerResponse,
generate_model_docs,
)
from worker import celery_app

router = APIRouter(
prefix="/images", tags=["Images"], dependencies=[Depends(verify_token)] # This will apply to all routes
)


@router.post("", response_model=ImageCreateResponse, description=generate_model_docs(), operation_id="images_create")
def create(image_request: ImageRequest, identity: Identity = Depends(verify_token)):
    """Enqueue an image generation task and return its id and initial status."""
    result = create_task(
        image_request.task_name,
        image_request.task_queue,
        image_request.model_dump(),
        identity,
    )
    # Normalize the Celery task id (a string) into the response's UUID field.
    return ImageCreateResponse(id=UUID(str(result.id)), status=result.status)


@router.get(
Expand All @@ -51,18 +41,8 @@ def models():

@router.get("/{id}", response_model=ImageResponse, operation_id="images_get")
def get(id: UUID):
result = AsyncResult(str(id), app=celery_app)

# Initialize response with common fields
response = ImageResponse(id=id, status=result.status, task_info=get_task_info(str(id)))

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

if result.info:
if isinstance(result.info, dict):
response.logs = result.info.get("logs", [])
result, task_info, logs = get_task_detailed(id)
response = ImageResponse(id=id, status=result.status, task_info=task_info, logs=logs)

# Add appropriate fields based on status
if result.successful():
Expand All @@ -79,4 +59,4 @@ def get(id: UUID):

@router.delete("/{id}", response_model=DeleteResponse, operation_id="images_delete")
def delete(id: UUID):
    """Cancel the image task with the given id."""
    return cancel_task(id)
39 changes: 15 additions & 24 deletions api/texts/router.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,10 @@
from uuid import UUID

from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends

from common.auth import verify_token
from common.schemas import DeleteResponse, Identity, TaskStatus
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from common.schemas import DeleteResponse, Identity
from common.task_helpers import cancel_task, create_task, get_task_detailed
from texts.schemas import (
MODEL_META,
TextCreateResponse,
Expand All @@ -15,24 +14,19 @@
TextWorkerResponse,
generate_model_docs,
)
from worker import celery_app

router = APIRouter(prefix="/texts", tags=["Texts"], dependencies=[Depends(verify_token)])


@router.post("", response_model=TextCreateResponse, operation_id="texts_create", description=generate_model_docs())
def create(text_request: TextRequest, identity: Identity = Depends(verify_token)):
    """Enqueue a text generation task and return its id and initial status."""
    result = create_task(
        text_request.task_name,
        text_request.task_queue,
        text_request.model_dump(),
        identity,
    )
    # Normalize the Celery task id (a string) into the response's UUID field.
    return TextCreateResponse(id=UUID(str(result.id)), status=result.status)


@router.get("/models", response_model=TextModelsResponse, summary="List text models", operation_id="texts_list_models")
Expand All @@ -42,19 +36,16 @@ def models():

@router.get("/{id}", response_model=TextResponse, operation_id="texts_get")
def get(id: UUID):
result = AsyncResult(str(id), app=celery_app)
result, task_info, logs = get_task_detailed(id)

# Initialize response with common fields
response = TextResponse(
id=id,
status=result.status,
task_info=get_task_info(str(id)),
task_info=task_info,
logs=logs,
)

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

# Add appropriate fields based on status
if result.successful():
result_data = TextWorkerResponse.model_validate(result.result)
Expand All @@ -67,4 +58,4 @@ def get(id: UUID):

@router.delete("/{id}", response_model=DeleteResponse, operation_id="texts_delete")
def delete(id: UUID):
    """Cancel the text task with the given id."""
    return cancel_task(id)
40 changes: 13 additions & 27 deletions api/videos/router.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
from uuid import UUID

from celery.result import AsyncResult
from fastapi import APIRouter, Depends, HTTPException, Response
from fastapi import APIRouter, Depends

from common.auth import verify_token
from common.schemas import DeleteResponse, Identity
from common.storage import signed_url_for_file
from common.task_helpers import cancel_task, get_queue_position_logs, get_task_info
from common.task_helpers import cancel_task, create_task, get_task_detailed
from videos.schemas import (
MODEL_META,
VideoCreateResponse,
Expand All @@ -16,24 +15,19 @@
VideoWorkerResponse,
generate_model_docs,
)
from worker import celery_app

router = APIRouter(prefix="/videos", tags=["Videos"], dependencies=[Depends(verify_token)])


@router.post("", response_model=VideoCreateResponse, operation_id="videos_create", description=generate_model_docs())
def create(video_request: VideoRequest, identity: Identity = Depends(verify_token)):
    """Enqueue a video generation task and return its id and initial status."""
    result = create_task(
        video_request.task_name,
        video_request.task_queue,
        video_request.model_dump(),
        identity,
    )
    # Normalize the Celery task id (a string) into the response's UUID field.
    return VideoCreateResponse(id=UUID(str(result.id)), status=result.status)


@router.get(
Expand All @@ -45,18 +39,10 @@ def models():

@router.get("/{id}", response_model=VideoResponse, operation_id="videos_get")
def get(id: UUID):
result = AsyncResult(str(id), app=celery_app)
result, task_info, logs = get_task_detailed(id)

# Initialize response with common fields
response = VideoResponse(id=id, status=result.status, task_info=get_task_info(str(id)))

# Use the helper to inject queue position into logs if still pending
if result.status == "PENDING":
response.logs = get_queue_position_logs(str(id))

if result.info:
if isinstance(result.info, dict):
response.logs = result.info.get("logs", [])
response = VideoResponse(id=id, status=result.status, task_info=task_info, logs=logs)

# Add appropriate fields based on status
if result.successful():
Expand All @@ -73,4 +59,4 @@ def get(id: UUID):

@router.delete("/{id}", response_model=DeleteResponse, operation_id="videos_delete")
def delete(id: UUID):
    """Cancel the video task with the given id."""
    return cancel_task(id)
3 changes: 3 additions & 0 deletions api/worker.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
from datetime import timedelta

from celery import Celery

from common.config import settings
Expand All @@ -13,3 +15,4 @@
result_backend_always_retry=False, # Do not always retry result backend operations
result_backend_max_retries=2, # Number of retries for result backend operations
)
celery_app.conf.result_expires = timedelta(days=settings.result_expires_days)
Loading
Loading