diff --git a/src/core/errors.py b/src/core/errors.py index 361fa990..dea1f50d 100644 --- a/src/core/errors.py +++ b/src/core/errors.py @@ -186,6 +186,7 @@ class AuthenticationRequiredError(ProblemDetailError): uri = "https://openml.org/problems/authentication-required" title = "Authentication Required" _default_status_code = HTTPStatus.UNAUTHORIZED + _default_code = 103 # PHP API doesn't differentiate class AuthenticationFailedError(ProblemDetailError): diff --git a/src/routers/dependencies.py b/src/routers/dependencies.py index 590ae36b..aeca95ab 100644 --- a/src/routers/dependencies.py +++ b/src/routers/dependencies.py @@ -5,7 +5,7 @@ from pydantic import BaseModel from sqlalchemy.ext.asyncio import AsyncConnection -from core.errors import AuthenticationFailedError +from core.errors import AuthenticationFailedError, AuthenticationRequiredError from database.setup import expdb_database, user_database from database.users import APIKey, User @@ -26,15 +26,22 @@ async def fetch_user( api_key: APIKey | None = None, user_data: Annotated[AsyncConnection | None, Depends(userdb_connection)] = None, ) -> User | None: - return await User.fetch(api_key, user_data) if api_key and user_data else None + if not (api_key and user_data): + return None + + user = await User.fetch(api_key, user_data) + if user: + return user + msg = "Invalid API key provided." + raise AuthenticationFailedError(msg) def fetch_user_or_raise( user: Annotated[User | None, Depends(fetch_user)] = None, ) -> User: if user is None: - msg = "Authentication failed" - raise AuthenticationFailedError(msg) + msg = "No API key provided." + raise AuthenticationRequiredError(msg) return user diff --git a/src/routers/openml/datasets.py b/src/routers/openml/datasets.py index fd0c9f0e..18547eee 100644 --- a/src/routers/openml/datasets.py +++ b/src/routers/openml/datasets.py @@ -13,7 +13,6 @@ import database.qualities from core.access import _user_has_access from core.errors import ( - AuthenticationRequiredError, DatasetAdminOnlyError, DatasetNoAccessError, DatasetNoDataFileError, @@ -338,13 +337,9 @@ async def get_dataset_features( async def update_dataset_status( dataset_id: Annotated[int, Body()], status: Annotated[Literal[DatasetStatus.ACTIVE, DatasetStatus.DEACTIVATED], Body()], - user: Annotated[User | None, Depends(fetch_user)], + user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)], ) -> dict[str, str | int]: - if user is None: - msg = "Updating dataset status requires authentication." - raise AuthenticationRequiredError(msg) - dataset = await _get_dataset_raise_otherwise(dataset_id, user, expdb) can_deactivate = dataset.uploader == user.user_id or await user.is_admin() diff --git a/src/routers/openml/study.py b/src/routers/openml/study.py index 6aeb9370..3b3ca2a4 100644 --- a/src/routers/openml/study.py +++ b/src/routers/openml/study.py @@ -18,7 +18,7 @@ ) from core.formatting import _str_to_bool from database.users import User -from routers.dependencies import expdb_connection, fetch_user +from routers.dependencies import expdb_connection, fetch_user, fetch_user_or_raise from schemas.core import Visibility from schemas.study import CreateStudy, Study, StudyStatus, StudyType @@ -62,7 +62,7 @@ class AttachDetachResponse(BaseModel): async def attach_to_study( study_id: Annotated[int, Body()], entity_ids: Annotated[list[int], Body()], - user: Annotated[User | None, Depends(fetch_user)] = None, + user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> AttachDetachResponse: assert expdb is not None # noqa: S101 @@ -99,13 +99,10 @@ async def attach_to_study( @router.post("/") async def create_study( study: CreateStudy, - user: Annotated[User | None, Depends(fetch_user)] = None, + user: Annotated[User, Depends(fetch_user_or_raise)], expdb: Annotated[AsyncConnection, Depends(expdb_connection)] = None, ) -> dict[Literal["study_id"], int]: assert expdb is not None # noqa: S101 - if user is None: - msg = "Creating a study requires authentication." - raise AuthenticationRequiredError(msg) if study.main_entity_type == StudyType.RUN and study.tasks: msg = "Cannot create a run study with tasks." raise StudyInvalidTypeError(msg) diff --git a/tests/dependencies/__init__.py b/tests/dependencies/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/dependencies/fetch_user_test.py b/tests/dependencies/fetch_user_test.py new file mode 100644 index 00000000..f6c31c47 --- /dev/null +++ b/tests/dependencies/fetch_user_test.py @@ -0,0 +1,38 @@ +import pytest +from sqlalchemy.ext.asyncio import AsyncConnection + +from core.errors import AuthenticationFailedError, AuthenticationRequiredError +from database.users import User +from routers.dependencies import fetch_user, fetch_user_or_raise +from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey + + +@pytest.mark.parametrize( + ("api_key", "user"), + [ + (ApiKey.ADMIN, ADMIN_USER), + (ApiKey.OWNER_USER, OWNER_USER), + (ApiKey.SOME_USER, SOME_USER), + ], +) +async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None: + db_user = await fetch_user(api_key, user_data=user_test) + assert isinstance(db_user, User) + assert user.user_id == db_user.user_id + assert set(await user.get_groups()) == set(await db_user.get_groups()) + + +async def test_fetch_user_no_key_no_user() -> None: + assert await fetch_user(api_key=None) is None + + +async def test_fetch_user_invalid_key_raises(user_test: AsyncConnection) -> None: + with pytest.raises(AuthenticationFailedError): + await fetch_user(api_key=ApiKey.INVALID, user_data=user_test) + + +async def test_fetch_user_or_raise_raises_if_no_user() -> None: + # This function calls `fetch_user` through dependency injection, + # so it only needs to correctly handle possible output of `fetch_user`. + with pytest.raises(AuthenticationRequiredError): + fetch_user_or_raise(user=None) diff --git a/tests/routers/openml/migration/datasets_migration_test.py b/tests/routers/openml/migration/datasets_migration_test.py index a411afb9..73874a39 100644 --- a/tests/routers/openml/migration/datasets_migration_test.py +++ b/tests/routers/openml/migration/datasets_migration_test.py @@ -118,16 +118,10 @@ async def test_error_unknown_dataset( assert error["detail"].startswith("No dataset") -@pytest.mark.parametrize( - "api_key", - [None, ApiKey.INVALID], -) async def test_private_dataset_no_user_no_access( py_api: httpx.AsyncClient, - api_key: str | None, ) -> None: - query = f"?api_key={api_key}" if api_key else "" - response = await py_api.get(f"/datasets/130{query}") + response = await py_api.get("/datasets/130") # New response is 403: Forbidden instead of 412: PRECONDITION FAILED assert response.status_code == HTTPStatus.FORBIDDEN diff --git a/tests/routers/openml/setups_tag_test.py b/tests/routers/openml/setups_tag_test.py index 44d31f24..b4a2d991 100644 --- a/tests/routers/openml/setups_tag_test.py +++ b/tests/routers/openml/setups_tag_test.py @@ -13,7 +13,7 @@ async def test_setup_tag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/tag", json={"setup_id": 1, "tag": "test_tag"}) assert response.status_code == HTTPStatus.UNAUTHORIZED assert response.json()["code"] == "103" - assert response.json()["detail"] == "Authentication failed" + assert response.json()["detail"] == "No API key provided." async def test_setup_tag_unknown_setup(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/setups_untag_test.py b/tests/routers/openml/setups_untag_test.py index 5ea8b515..985df7da 100644 --- a/tests/routers/openml/setups_untag_test.py +++ b/tests/routers/openml/setups_untag_test.py @@ -13,7 +13,7 @@ async def test_setup_untag_missing_auth(py_api: httpx.AsyncClient) -> None: response = await py_api.post("/setup/untag", json={"setup_id": 1, "tag": "test_tag"}) assert response.status_code == HTTPStatus.UNAUTHORIZED assert response.json()["code"] == "103" - assert response.json()["detail"] == "Authentication failed" + assert response.json()["detail"] == "No API key provided." async def test_setup_untag_unknown_setup(py_api: httpx.AsyncClient) -> None: diff --git a/tests/routers/openml/users_test.py b/tests/routers/openml/users_test.py deleted file mode 100644 index 7250a115..00000000 --- a/tests/routers/openml/users_test.py +++ /dev/null @@ -1,27 +0,0 @@ -import pytest -from sqlalchemy.ext.asyncio import AsyncConnection - -from database.users import User -from routers.dependencies import fetch_user -from tests.users import ADMIN_USER, OWNER_USER, SOME_USER, ApiKey - - -@pytest.mark.parametrize( - ("api_key", "user"), - [ - (ApiKey.ADMIN, ADMIN_USER), - (ApiKey.OWNER_USER, OWNER_USER), - (ApiKey.SOME_USER, SOME_USER), - ], -) -async def test_fetch_user(api_key: str, user: User, user_test: AsyncConnection) -> None: - db_user = await fetch_user(api_key, user_data=user_test) - assert db_user is not None - assert user.user_id == db_user.user_id - assert set(await user.get_groups()) == set(await db_user.get_groups()) - - -async def test_fetch_user_invalid_key_returns_none(user_test: AsyncConnection) -> None: - assert await fetch_user(api_key=None, user_data=user_test) is None - invalid_key = "f" * 32 - assert await fetch_user(api_key=invalid_key, user_data=user_test) is None