-
-
Notifications
You must be signed in to change notification settings - Fork 52
feat: Add GET/POST /task/list endpoint #277
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
1a507f8
73817ee
4e9ee5d
4affe39
c8d96e4
55a0d60
cb74d93
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,18 +1,21 @@ | ||
| import asyncio | ||
| import json | ||
| import re | ||
| from typing import Annotated, cast | ||
| from enum import StrEnum | ||
| from typing import Annotated, Any, cast | ||
|
|
||
| import xmltodict | ||
| from fastapi import APIRouter, Depends | ||
| from sqlalchemy import RowMapping, text | ||
| from fastapi import APIRouter, Body, Depends | ||
| from sqlalchemy import bindparam, text | ||
| from sqlalchemy.engine import RowMapping | ||
| from sqlalchemy.ext.asyncio import AsyncConnection | ||
|
|
||
| import config | ||
| import database.datasets | ||
| import database.tasks | ||
| from core.errors import InternalError, TaskNotFoundError | ||
| from routers.dependencies import expdb_connection | ||
| from core.errors import InternalError, NoResultsError, TaskNotFoundError | ||
| from routers.dependencies import Pagination, expdb_connection | ||
| from routers.types import CasualString128, IntegerRange, SystemString64, integer_range_regex | ||
| from schemas.datasets.openml import Task | ||
|
|
||
| router = APIRouter(prefix="/tasks", tags=["tasks"]) | ||
|
|
@@ -158,6 +161,253 @@ async def _fill_json_template( # noqa: C901 | |
| return template.replace("[CONSTANT:base_url]", server_url) | ||
|
|
||
|
|
||
class TaskStatusFilter(StrEnum):
    """Valid values for the status filter.

    Mirrors the dataset status values stored in the database, plus ``ALL``,
    which disables status filtering entirely (see ``list_tasks``).
    """

    ACTIVE = "active"
    DEACTIVATED = "deactivated"
    IN_PREPARATION = "in_preparation"
    # Sentinel: match every status; never stored in the database itself.
    ALL = "all"
|
|
||
|
|
||
# Dataset qualities (from the `data_quality` table) that are included in each
# task returned by `list_tasks`. Only this subset is fetched and exposed, not
# every quality stored for the dataset.
QUALITIES_TO_SHOW = [
    "MajorityClassSize",
    "MaxNominalAttDistinctValues",
    "MinorityClassSize",
    "NumberOfClasses",
    "NumberOfFeatures",
    "NumberOfInstances",
    "NumberOfInstancesWithMissingValues",
    "NumberOfMissingValues",
    "NumberOfNumericFeatures",
    "NumberOfSymbolicFeatures",
]
|
|
||
# `task_inputs.input` names that are included in each task's "input" list in
# the `list_tasks` response; other task inputs are omitted from list output.
BASIC_TASK_INPUTS = [
    "source_data",
    "target_feature",
    "estimation_procedure",
    "evaluation_measures",
]
|
|
||
|
|
||
def _quality_clause(quality: str, range_: str | None) -> str:
    """Return a SQL WHERE-clause fragment filtering tasks by a dataset quality range.

    Restricts tasks to those whose 'source_data' dataset has ``quality``
    within ``range_``. ``range_`` is either an exact value ('100') or an
    inclusive range ('50..200'). An empty/None range yields an empty string,
    so callers can splice the result into a query unconditionally.

    Raises ValueError when ``range_`` does not match ``integer_range_regex``.
    """
    if not range_:
        return ""
    match = re.match(integer_range_regex, range_)
    if match is None:
        msg = f"`range_` not a valid range: {range_}"
        raise ValueError(msg)
    lower, upper = match.groups()
    if upper:
        # `upper` still carries the ".." separator (e.g. "..200"); strip it.
        value = f"`value` BETWEEN {lower} AND {upper[2:]}"
    else:
        value = f"`value`={lower}"
    # Inner subquery selects datasets whose quality is in range; the outer one
    # maps those datasets back to tasks via their 'source_data' input.
    # Interpolation is safe here: `quality` is an internal constant and the
    # numbers were validated by the regex above.
    return f"""
    AND t.`task_id` IN (
        SELECT ti.`task_id` FROM task_inputs ti
        WHERE ti.`input`='source_data' AND ti.`value` IN (
            SELECT `data` FROM data_quality
            WHERE `quality`='{quality}' AND {value}
        )
    )
    """  # noqa: S608
|
|
||
|
|
||
@router.post(path="/list", description="Provided for convenience, same as `GET` endpoint.")
@router.get(path="/list")
async def list_tasks(  # noqa: PLR0913, PLR0912, C901, PLR0915
    pagination: Annotated[Pagination, Body(default_factory=Pagination)],
    task_type_id: Annotated[int | None, Body(description="Filter by task type id.")] = None,
    tag: Annotated[str | None, SystemString64] = None,
    data_tag: Annotated[str | None, SystemString64] = None,
    status: Annotated[TaskStatusFilter, Body()] = TaskStatusFilter.ACTIVE,
    task_id: Annotated[list[int] | None, Body(description="Filter by task id(s).")] = None,
    data_id: Annotated[list[int] | None, Body(description="Filter by dataset id(s).")] = None,
    data_name: Annotated[str | None, CasualString128] = None,
    number_instances: Annotated[str | None, IntegerRange] = None,
    number_features: Annotated[str | None, IntegerRange] = None,
    number_classes: Annotated[str | None, IntegerRange] = None,
    number_missing_values: Annotated[str | None, IntegerRange] = None,
    expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None,
) -> list[dict[str, Any]]:
    """List tasks, optionally filtered by type, tag, status, dataset properties, and more.

    Builds one paginated main query over `task`/`task_type`/`dataset` (filters
    are appended as AND clauses), then enriches the page of results with three
    follow-up queries: task inputs, dataset qualities, and task tags.

    Raises:
        NoResultsError: when a filter list is empty or no task matches.
        ValueError: when a number_* range string is malformed (via _quality_clause).
    """
    # `expdb` defaults to None only to satisfy FastAPI's dependency signature;
    # the dependency always supplies a connection at request time.
    assert expdb is not None  # noqa: S101

    # Filter fragments (each starts with "AND ...") and their bound parameters.
    clauses: list[str] = []
    parameters: dict[str, Any] = {
        # Clamp to non-negative so malformed pagination cannot produce SQL errors.
        "offset": max(0, pagination.offset),
        "limit": max(0, pagination.limit),
    }

    # Datasets with no status row yet are treated as 'in_preparation'
    # (IFNULL below), matching the main query's status projection.
    if status != TaskStatusFilter.ALL:
        clauses.append("AND IFNULL(ds.`status`, 'in_preparation') = :status")
        parameters["status"] = status

    if task_type_id is not None:
        clauses.append("AND t.`ttid` = :task_type_id")
        parameters["task_type_id"] = task_type_id

    if tag is not None:
        clauses.append("AND t.`task_id` IN (SELECT `id` FROM task_tag WHERE `tag` = :tag)")
        parameters["tag"] = tag

    if data_tag is not None:
        clauses.append("AND d.`did` IN (SELECT `id` FROM dataset_tag WHERE `tag` = :data_tag)")
        parameters["data_tag"] = data_tag

    if data_name is not None:
        clauses.append("AND d.`name` = :data_name")
        parameters["data_name"] = data_name

    if task_id is not None:
        # An explicitly provided but empty id list can match nothing; fail fast
        # rather than emit an invalid/empty IN clause. (This checks only for an
        # empty list — a non-empty list proceeds to the IN filter.)
        if not task_id:
            msg = "No tasks match the search criteria."
            raise NoResultsError(msg)
        clauses.append("AND t.`task_id` IN :task_ids")
        parameters["task_ids"] = task_id

    if data_id is not None:
        if not data_id:
            msg = "No tasks match the search criteria."
            raise NoResultsError(msg)
        clauses.append("AND d.`did` IN :data_ids")
        parameters["data_ids"] = data_id

    # Range filters on dataset qualities; each returns "" when unset.
    where_number_instances = _quality_clause("NumberOfInstances", number_instances)
    where_number_features = _quality_clause("NumberOfFeatures", number_features)
    where_number_classes = _quality_clause("NumberOfClasses", number_classes)
    where_number_missing_values = _quality_clause("NumberOfMissingValues", number_missing_values)

    # subquery to get the latest status per dataset (dataset_status is a history table)
    status_subquery = """
    SELECT ds1.did, ds1.status
    FROM dataset_status ds1
    WHERE ds1.status_date = (
        SELECT MAX(ds2.status_date) FROM dataset_status ds2
        WHERE ds1.did = ds2.did
    )
    """

    # The interpolated pieces are either internal SQL fragments built above or
    # parameter-bound clauses, hence the S608 suppression.
    main_query = text(
        f"""
        SELECT
            t.`task_id`,
            t.`ttid` AS task_type_id,
            tt.`name` AS task_type,
            d.`did`,
            d.`name`,
            d.`format`,
            IFNULL(ds.`status`, 'in_preparation') AS status
        FROM task t
        JOIN task_type tt
            ON tt.`ttid` = t.`ttid`
        JOIN task_inputs ti_source
            ON ti_source.`task_id` = t.`task_id`
            AND ti_source.`input` = 'source_data'
        JOIN dataset d
            ON d.`did` = ti_source.`value`
        LEFT JOIN ({status_subquery}) ds
            ON ds.`did` = d.`did`
        WHERE 1=1
        {where_number_instances}
        {where_number_features}
        {where_number_classes}
        {where_number_missing_values}
        {" ".join(clauses)}
        GROUP BY t.`task_id`, t.`ttid`, tt.`name`, d.`did`, d.`name`, d.`format`, ds.`status`
        ORDER BY t.`task_id`
        LIMIT :limit OFFSET :offset
        """,  # noqa: S608
    )

    # IN-list parameters need `expanding=True` so SQLAlchemy turns the bound
    # list into the right number of placeholders.
    if task_id is not None:
        main_query = main_query.bindparams(bindparam("task_ids", expanding=True))
    if data_id is not None:
        main_query = main_query.bindparams(bindparam("data_ids", expanding=True))

    result = await expdb.execute(main_query, parameters=parameters)
    rows = result.mappings().all()

    if not rows:
        msg = "No tasks match the search criteria."
        raise NoResultsError(msg)

    # Index the page of tasks by task_id for O(1) enrichment below.
    columns = ["task_id", "task_type_id", "task_type", "did", "name", "format", "status"]
    tasks: dict[int, dict[str, Any]] = {
        row["task_id"]: {col: row[col] for col in columns} for row in rows
    }
    task_ids: list[int] = list(tasks.keys())
    dataset_ids: list[int] = list({t["did"] for t in tasks.values()})

    # NOTE(review): the input, quality, and tag queries below are independent
    # of each other and could in principle run concurrently (e.g. with
    # asyncio.gather) — but confirm whether a single AsyncConnection may be
    # used concurrently before changing this; they run sequentially here.
    inputs_query = text(
        """
        SELECT `task_id`, `input`, `value`
        FROM task_inputs
        WHERE `task_id` IN :task_ids
        AND `input` IN :basic_inputs
        """,
    ).bindparams(
        bindparam("task_ids", expanding=True),
        bindparam("basic_inputs", expanding=True),
    )
    inputs_result = await expdb.execute(
        inputs_query,
        parameters={"task_ids": task_ids, "basic_inputs": BASIC_TASK_INPUTS},
    )
    for row in inputs_result.all():
        tasks[row.task_id].setdefault("input", []).append(
            {"name": row.input, "value": row.value},
        )

    qualities_query = text(
        """
        SELECT `data`, `quality`, `value`
        FROM data_quality
        WHERE `data` IN :dataset_ids
        AND `quality` IN :quality_names
        """,
    ).bindparams(
        bindparam("dataset_ids", expanding=True),
        bindparam("quality_names", expanding=True),
    )
    qualities_result = await expdb.execute(
        qualities_query,
        parameters={"dataset_ids": dataset_ids, "quality_names": QUALITIES_TO_SHOW},
    )
    # multiple tasks can reference the same dataset; map dataset_id -> [task_id, ...]
    did_to_task_ids: dict[int, list[int]] = {}
    for tid, t in tasks.items():
        did_to_task_ids.setdefault(t["did"], []).append(tid)
    for row in qualities_result.all():
        # Fan each dataset quality out to every task backed by that dataset.
        for tid in did_to_task_ids.get(row.data, []):
            tasks[tid].setdefault("quality", []).append(
                {"name": row.quality, "value": str(row.value)},
            )

    tags_query = text(
        """
        SELECT `id`, `tag`
        FROM task_tag
        WHERE `id` IN :task_ids
        """,
    ).bindparams(bindparam("task_ids", expanding=True))
    tags_result = await expdb.execute(tags_query, parameters={"task_ids": task_ids})

    for row in tags_result.all():
        tasks[row.id].setdefault("tag", []).append(row.tag)

    # ensure every task has all expected keys even if no related rows were found
    for task in tasks.values():
        task.setdefault("input", [])
        task.setdefault("quality", [])
        task.setdefault("tag", [])

    # dict preserves insertion order, so output stays ordered by task_id.
    return list(tasks.values())
|
|
||
|
|
||
| @router.get("/{task_id}") | ||
| async def get_task( | ||
| task_id: int, | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.