diff --git a/src/dstack/_internal/core/models/keys.py b/src/dstack/_internal/core/models/keys.py new file mode 100644 index 000000000..1a78f1911 --- /dev/null +++ b/src/dstack/_internal/core/models/keys.py @@ -0,0 +1,12 @@ +import datetime +import uuid + +from dstack._internal.core.models.common import CoreModel + + +class PublicKeyInfo(CoreModel): + id: uuid.UUID + added_at: datetime.datetime + name: str + type: str + fingerprint: str diff --git a/src/dstack/_internal/server/app.py b/src/dstack/_internal/server/app.py index 02d536d9c..5f595da01 100644 --- a/src/dstack/_internal/server/app.py +++ b/src/dstack/_internal/server/app.py @@ -42,6 +42,7 @@ metrics, projects, prometheus, + public_keys, repos, runs, secrets, @@ -259,6 +260,7 @@ def register_routes(app: FastAPI, ui: bool = True): app.include_router(exports.project_router) app.include_router(imports.project_router) app.include_router(sshproxy.router) + app.include_router(public_keys.router) @app.exception_handler(ForbiddenError) async def forbidden_error_handler(request: Request, exc: ForbiddenError): diff --git a/src/dstack/_internal/server/migrations/versions/2026/03_24_1145_59e328ced74c_add_userpublickeymodel.py b/src/dstack/_internal/server/migrations/versions/2026/03_24_1145_59e328ced74c_add_userpublickeymodel.py new file mode 100644 index 000000000..6a5e30afb --- /dev/null +++ b/src/dstack/_internal/server/migrations/versions/2026/03_24_1145_59e328ced74c_add_userpublickeymodel.py @@ -0,0 +1,50 @@ +"""Add UserPublicKeyModel + +Revision ID: 59e328ced74c +Revises: c1c2ecaee45c +Create Date: 2026-03-24 11:45:13.560594+00:00 + +""" + +import sqlalchemy as sa +import sqlalchemy_utils +from alembic import op + +import dstack._internal.server.models + +# revision identifiers, used by Alembic. +revision = "59e328ced74c" +down_revision = "c1c2ecaee45c" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.create_table( + "user_public_keys", + sa.Column("id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("created_at", dstack._internal.server.models.NaiveDateTime(), nullable=False), + sa.Column("user_id", sqlalchemy_utils.types.uuid.UUIDType(binary=False), nullable=False), + sa.Column("name", sa.String(length=100), nullable=False), + sa.Column("type", sa.String(length=100), nullable=False), + sa.Column("fingerprint", sa.String(length=100), nullable=False), + sa.Column("key", sa.Text(), nullable=False), + sa.ForeignKeyConstraint( + ["user_id"], + ["users.id"], + name=op.f("fk_user_public_keys_user_id_users"), + ondelete="CASCADE", + ), + sa.PrimaryKeyConstraint("id", name=op.f("pk_user_public_keys")), + sa.UniqueConstraint( + "user_id", "fingerprint", name="uq_user_public_keys_user_id_fingerprint" + ), + ) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_table("user_public_keys") + # ### end Alembic commands ### diff --git a/src/dstack/_internal/server/models.py b/src/dstack/_internal/server/models.py index b599c4314..035ebf554 100644 --- a/src/dstack/_internal/server/models.py +++ b/src/dstack/_internal/server/models.py @@ -1115,3 +1115,24 @@ class ExportedFleetModel(BaseModel): ForeignKey("fleets.id", ondelete="CASCADE"), index=True ) fleet: Mapped["FleetModel"] = relationship() + + +class UserPublicKeyModel(BaseModel): + __tablename__ = "user_public_keys" + __table_args__ = ( + UniqueConstraint("user_id", "fingerprint", name="uq_user_public_keys_user_id_fingerprint"), + ) + + id: Mapped[uuid.UUID] = mapped_column( + UUIDType(binary=False), primary_key=True, default=uuid.uuid4 + ) + created_at: Mapped[datetime] = mapped_column(NaiveDateTime, default=get_current_datetime) + user_id: Mapped[uuid.UUID] = mapped_column(ForeignKey("users.id", ondelete="CASCADE")) + user: Mapped["UserModel"] = relationship() + name: Mapped[str] = mapped_column(String(100)) + type: Mapped[str] = mapped_column(String(100)) + """`type` is a key type identifier used by OpenSSH, e.g., `ssh-rsa`, `ecdsa-sha2-nistp521`.""" + fingerprint: Mapped[str] = mapped_column(String(100)) + """`fingerprint` stores a key digest in the format used by OpenSSH: `SHA256:`.""" + key: Mapped[str] = mapped_column(Text) + """`key` stores a public key in the OpenSSH disk (ASCII-armored) format.""" diff --git a/src/dstack/_internal/server/routers/public_keys.py b/src/dstack/_internal/server/routers/public_keys.py new file mode 100644 index 000000000..2a78bb0e1 --- /dev/null +++ b/src/dstack/_internal/server/routers/public_keys.py @@ -0,0 +1,54 @@ +from typing import Annotated + +from fastapi import APIRouter, Depends +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.models.keys import PublicKeyInfo +from dstack._internal.server.db import get_session +from dstack._internal.server.models import UserModel +from dstack._internal.server.schemas.public_keys import ( + AddPublicKeyRequest, + DeletePublicKeysRequest, +) +from dstack._internal.server.security.permissions import Authenticated +from dstack._internal.server.services import public_keys as public_keys_services +from dstack._internal.server.utils.routers import ( + CustomORJSONResponse, + get_base_api_additional_responses, +) + +router = APIRouter( + prefix="/api/users/public_keys", + tags=["user public keys"], + responses=get_base_api_additional_responses(), +) + + +@router.post("/list", response_model=list[PublicKeyInfo]) +async def list_user_public_keys( + session: Annotated[AsyncSession, Depends(get_session)], + user: Annotated[UserModel, Depends(Authenticated())], +): + public_keys = await public_keys_services.list_user_public_keys(session=session, user=user) + return CustomORJSONResponse(public_keys) + + +@router.post("/add", response_model=PublicKeyInfo) +async def add_user_public_key( + body: AddPublicKeyRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user: Annotated[UserModel, Depends(Authenticated())], +): + public_key = await public_keys_services.add_user_public_key( + session=session, user=user, key=body.key, name=body.name + ) + return CustomORJSONResponse(public_key) + + +@router.post("/delete") +async def delete_user_public_keys( + body: DeletePublicKeysRequest, + session: Annotated[AsyncSession, Depends(get_session)], + user: Annotated[UserModel, Depends(Authenticated())], +): + await public_keys_services.delete_user_public_keys(session=session, user=user, ids=body.ids) diff --git a/src/dstack/_internal/server/schemas/public_keys.py b/src/dstack/_internal/server/schemas/public_keys.py new file mode 100644 index 000000000..97fcee11e --- /dev/null +++ b/src/dstack/_internal/server/schemas/public_keys.py @@ -0,0 +1,13 @@ +import uuid +from typing import Optional + +from dstack._internal.core.models.common import CoreModel + + +class AddPublicKeyRequest(CoreModel): + key: str + name: Optional[str] = None + + +class DeletePublicKeysRequest(CoreModel): + ids: list[uuid.UUID] diff --git a/src/dstack/_internal/server/services/public_keys.py b/src/dstack/_internal/server/services/public_keys.py new file mode 100644 index 000000000..52c642fa3 --- /dev/null +++ b/src/dstack/_internal/server/services/public_keys.py @@ -0,0 +1,258 @@ +import asyncio +import base64 +import hashlib +import subprocess +import uuid +from collections.abc import Iterable +from typing import Any, ClassVar, Optional + +import paramiko.pkey +import sqlalchemy.exc +from sqlalchemy import delete, select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.core.errors import DstackError, ResourceExistsError, ServerClientError +from dstack._internal.core.models.keys import PublicKeyInfo +from dstack._internal.server.models import UserModel, UserPublicKeyModel +from dstack._internal.server.services import events +from dstack._internal.utils.logging import get_logger +from dstack._internal.utils.ssh import find_ssh_util + +logger = get_logger(__name__) + +supported_key_types = [ + "ssh-rsa", + "ecdsa-sha2-nistp256", + "ecdsa-sha2-nistp384", + "ecdsa-sha2-nistp521", + "ssh-ed25519", + "sk-ecdsa-sha2-nistp256@openssh.com", + "sk-ssh-ed25519@openssh.com", +] + + +class PublicKeyError(DstackError): + # The message displayed to the user, should not contain internal/sensitive info + # Any debug info should be passed to the constructor as positional arguments + # and accessed via debug_message() + msg: ClassVar = "Public key error" + + def __init__(self, *args: Any, **kwargs: str) -> None: + super().__init__(*args) + self._kwargs = kwargs + + def __str__(self) -> str: + return self.msg.format(**self._kwargs) + + def debug_message(self) -> str: + return super().__str__() + + +class InvalidPublicKeyError(PublicKeyError): + msg = "Invalid public key, must be in OpenSSH public key format" + + +class UnsupportedPublicKeyError(PublicKeyError): + msg = "Unsupported key type: {type}" + + +async def list_user_public_keys(session: AsyncSession, user: UserModel) -> list[PublicKeyInfo]: + res = await session.execute( + select(UserPublicKeyModel) + .where(UserPublicKeyModel.user_id == user.id) + .order_by(UserPublicKeyModel.created_at.desc()) + ) + user_public_keys = res.scalars().all() + return [user_public_key_model_to_public_key_info(k) for k in user_public_keys] + + +async def add_user_public_key( + session: AsyncSession, user: UserModel, key: str, name: Optional[str] = None +) -> PublicKeyInfo: + try: + type_, blob, comment = parse_openssh_public_key(key) + await validate_openssh_public_key(key) + except PublicKeyError as e: + logger.debug("User public key validation error: %s: %s", e, e.debug_message()) + raise ServerClientError(str(e)) + except (TimeoutError, OSError) as e: + logger.warning("Failed to validate user public key: %s", e) + raise ServerClientError("Failed to validate the key. Try later") + + if not name: + name = comment or hashlib.md5(blob).hexdigest() + fingerprint = get_openssh_public_key_fingerprint(blob) + + user_public_key = UserPublicKeyModel( + user=user, + name=name, + type=type_, + fingerprint=fingerprint, + key=key, + ) + try: + async with session.begin_nested(): + session.add(user_public_key) + except sqlalchemy.exc.IntegrityError: + raise ResourceExistsError() + events.emit( + session, + f"Public key added. Fingerprint: {fingerprint}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(user)], + ) + await session.commit() + + return user_public_key_model_to_public_key_info(user_public_key) + + +async def delete_user_public_keys( + session: AsyncSession, user: UserModel, ids: Iterable[uuid.UUID] +) -> None: + res = await session.execute( + delete(UserPublicKeyModel) + .where( + UserPublicKeyModel.user_id == user.id, + UserPublicKeyModel.id.in_(ids), + ) + .returning(UserPublicKeyModel.fingerprint) + ) + for fingerprint in res.scalars().all(): + events.emit( + session, + f"Public key deleted. Fingerprint: {fingerprint}", + actor=events.UserActor.from_user(user), + targets=[events.Target.from_model(user)], + ) + await session.commit() + + +def parse_openssh_public_key(key: str) -> tuple[str, bytes, Optional[str]]: + """ + Parses OpenSSH public key in disk format. + + Args: + key: public key file contents. + + Returns: + key type, blob in wire format, and optional comment. + + Raises: + InvalidPublicKeyError: if the key disk format is not valid or the declared disk format + key type does not match the actual key type in the blob. + Note, the key blob is not checked, further validation is required. + UnsupportedPublicKeyError: if the key type is not supported. + """ + # OpenSSH disk (ASCII-armored) format for public keys: + # [ ] + # See: section 4.1 "Public key format" + # https://cvsweb.openbsd.org/checkout/src/usr.bin/ssh/PROTOCOL + # e.g., + # * without comment: + # ssh-ed25519 AAAAC3NzaC1lZ[...truncated...] + # * with default comment added by ssh-keygen: + # ssh-rsa AAAAB3NzaC1yc2EAAAADAQ[...truncated...] username@hostname + # * with user-provided comment: + # sk-ssh-ed25519@openssh.com AAAAGnN[...truncated...] my FIDO2 key + + # OpenSSH wire format for public keys: + # string certificate or public key format identifier + # byte[n] key/certificate data + # See: https://datatracker.ietf.org/doc/html/rfc4253#section-6.6 + # Where string type is encoded as follows: + # > They are stored as a uint32 containing its length (number of bytes that follow) + # > and zero (= empty string) or more bytes that are the value of the string. + # > Terminating null characters are not used. + # See: https://datatracker.ietf.org/doc/html/rfc4251#section-5 + # e.g., + # 00 00 00 0b 73 73 68 2d 65 64 32 35 35 31 39 |....ssh-ed25519| + + # PublicBlob.from_string() ensures that: + # * there are at least two fields in the disk format: and + # * key type in the disk format (PublicBlob.key_type) matches key type in the wire format + try: + pb = paramiko.pkey.PublicBlob.from_string(key) + except ValueError as e: + raise InvalidPublicKeyError(str(e)) from e + if pb.key_type not in supported_key_types: + raise UnsupportedPublicKeyError(type=pb.key_type) + return pb.key_type, pb.key_blob, pb.comment or None + + +def get_openssh_public_key_fingerprint(key_blob: bytes) -> str: + """ + Returns OpenSSH public key fingerprint in the format used by OpenSSH. + + See `paramiko.pkey.PKey.fingerprint` for the implementation. + + Args: + key_blob: public key blob in OpenSSH wire format. + + Returns: + A fingerprint as an ASCII string, the same format OpenSSH uses. + """ + sha256_digest_armored = base64.b64encode(hashlib.sha256(key_blob).digest()).decode() + return f"SHA256:{sha256_digest_armored.rstrip('=')}" + + +async def validate_openssh_public_key(key: str) -> None: + """ + Validates OpenSSH public key in disk format using `ssh-keygen`. + + Args: + key: public key file contents. + + Raises: + InvalidPublicKeyError: the key is not valid - `ssh-keygen` returned non-zero exit status. + TimeoutError: validation timeout expired. + OSerror: failed to execute `ssh-keygen` subprocess. + """ + proc = None + try: + proc = await asyncio.create_subprocess_exec( + _get_ssh_keygen_executable(), + "-l", + "-f", + "-", + stdin=subprocess.PIPE, + stdout=subprocess.PIPE, + stderr=subprocess.STDOUT, + ) + output, _ = await asyncio.wait_for(proc.communicate(input=key.encode()), timeout=3) + except asyncio.TimeoutError: + if proc is not None: + proc.kill() + raise TimeoutError("Validation timeout expired") + except OSError: + if proc is not None: + proc.kill() + raise + if proc.returncode != 0: + raise InvalidPublicKeyError(output) + + +def user_public_key_model_to_public_key_info( + user_public_key_model: UserPublicKeyModel, +) -> PublicKeyInfo: + return PublicKeyInfo( + id=user_public_key_model.id, + added_at=user_public_key_model.created_at, + name=user_public_key_model.name, + type=user_public_key_model.type, + fingerprint=user_public_key_model.fingerprint, + ) + + +_ssh_keygen_executable: Optional[str] = None + + +def _get_ssh_keygen_executable() -> str: + global _ssh_keygen_executable + if _ssh_keygen_executable is not None: + return _ssh_keygen_executable + ssh_keygen_path = find_ssh_util("ssh-keygen") + if ssh_keygen_path is None: + _ssh_keygen_executable = "ssh-keygen" + else: + _ssh_keygen_executable = str(ssh_keygen_path) + return _ssh_keygen_executable diff --git a/src/dstack/_internal/server/services/sshproxy/handlers.py b/src/dstack/_internal/server/services/sshproxy/handlers.py index 9f0397114..713b3c802 100644 --- a/src/dstack/_internal/server/services/sshproxy/handlers.py +++ b/src/dstack/_internal/server/services/sshproxy/handlers.py @@ -3,7 +3,7 @@ from sqlalchemy import select from sqlalchemy.ext.asyncio import AsyncSession -from sqlalchemy.orm import joinedload +from sqlalchemy.orm import joinedload, load_only from dstack._internal.core.models.runs import JobStatus from dstack._internal.server.models import ( @@ -12,6 +12,7 @@ ProjectModel, RunModel, UserModel, + UserPublicKeyModel, ) from dstack._internal.server.schemas.sshproxy import GetUpstreamResponse, UpstreamHost from dstack._internal.server.services.jobs import get_job_runtime_data, get_job_spec @@ -46,7 +47,7 @@ async def get_upstream_response( ), ( joinedload(JobModel.run, innerjoin=True) - .load_only(RunModel.run_spec) + .load_only(RunModel.run_spec, RunModel.user_id) .joinedload(RunModel.user, innerjoin=True) .load_only(UserModel.ssh_public_key) ), @@ -75,7 +76,12 @@ async def get_upstream_response( if username is not None: hosts[-1].user = username - authorized_keys: set[str] = set() + res = await session.execute( + select(UserPublicKeyModel) + .where(UserPublicKeyModel.user_id == job.run.user_id) + .options(load_only(UserPublicKeyModel.key)) + ) + authorized_keys = {k.key for k in res.scalars().all()} if (run_spec_key := get_run_spec(job.run).ssh_key_pub) is not None: authorized_keys.add(run_spec_key) if (user_key := job.run.user.ssh_public_key) is not None: diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index e38907b01..a3d0068c0 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -114,6 +114,7 @@ RunModel, SecretModel, UserModel, + UserPublicKeyModel, VolumeAttachmentModel, VolumeModel, ) @@ -163,6 +164,28 @@ async def create_user( return user +async def create_user_public_key( + session: AsyncSession, + user: UserModel, + name: str = "test-key", + type: str = "ssh-ed25519", + fingerprint: str = "SHA256:testfingerprint", + key: str = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5", + created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), +) -> UserPublicKeyModel: + user_public_key = UserPublicKeyModel( + user=user, + name=name, + type=type, + fingerprint=fingerprint, + key=key, + created_at=created_at, + ) + session.add(user_public_key) + await session.commit() + return user_public_key + + async def create_project( session: AsyncSession, owner: Optional[UserModel] = None, diff --git a/src/tests/_internal/server/routers/test_public_keys.py b/src/tests/_internal/server/routers/test_public_keys.py new file mode 100644 index 000000000..1af915821 --- /dev/null +++ b/src/tests/_internal/server/routers/test_public_keys.py @@ -0,0 +1,270 @@ +import uuid +from datetime import datetime, timezone +from unittest.mock import AsyncMock + +import pytest +from freezegun import freeze_time +from httpx import AsyncClient +from sqlalchemy import select +from sqlalchemy.ext.asyncio import AsyncSession + +from dstack._internal.server.models import UserPublicKeyModel +from dstack._internal.server.testing.common import ( + create_user, + create_user_public_key, + get_auth_headers, +) +from dstack._internal.server.testing.matchers import SomeUUID4Str + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db") +class TestListUserPublicKeys: + async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): + response = await client.post("/api/users/public_keys/list") + assert response.status_code in [401, 403] + + async def test_lists_own_public_keys(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session) + key = await create_user_public_key( + session=session, + user=user, + name="my-key", + type="ssh-ed25519", + fingerprint="SHA256:testfingerprint", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + response = await client.post( + "/api/users/public_keys/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json() == [ + { + "id": str(key.id), + "added_at": "2023-01-02T03:04:00+00:00", + "name": "my-key", + "type": "ssh-ed25519", + "fingerprint": "SHA256:testfingerprint", + } + ] + + async def test_does_not_list_other_users_keys( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + other_user = await create_user(session=session, name="other_user") + await create_user_public_key(session=session, user=other_user) + response = await client.post( + "/api/users/public_keys/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + assert response.json() == [] + + async def test_returns_keys_in_reverse_chronological_order( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + key1 = await create_user_public_key( + session=session, + user=user, + name="older-key", + fingerprint="SHA256:fingerprint1", + created_at=datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), + ) + key2 = await create_user_public_key( + session=session, + user=user, + name="newer-key", + fingerprint="SHA256:fingerprint2", + created_at=datetime(2023, 1, 2, 3, 5, tzinfo=timezone.utc), + ) + response = await client.post( + "/api/users/public_keys/list", + headers=get_auth_headers(user.token), + ) + assert response.status_code == 200 + data = response.json() + assert len(data) == 2 + assert data[0]["id"] == str(key2.id) + assert data[1]["id"] == str(key1.id) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db") +class TestAddUserPublicKey: + PUBLIC_KEY_NO_COMMENT = "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAA" + PUBLIC_KEY = f"{PUBLIC_KEY_NO_COMMENT} test@example.com" + + @pytest.fixture + def validate_openssh_public_key_mock(self, monkeypatch: pytest.MonkeyPatch) -> AsyncMock: + mock = AsyncMock() + monkeypatch.setattr( + "dstack._internal.server.services.public_keys.validate_openssh_public_key", mock + ) + return mock + + async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/users/public_keys/add", + json={"key": self.PUBLIC_KEY}, + ) + assert response.status_code in [401, 403] + + @freeze_time(datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc)) + async def test_adds_valid_public_key( + self, + session: AsyncSession, + client: AsyncClient, + validate_openssh_public_key_mock: AsyncMock, + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": self.PUBLIC_KEY}, + ) + assert response.status_code == 200 + assert response.json() == { + "id": SomeUUID4Str(), + "type": "ssh-ed25519", + "name": "test@example.com", + "fingerprint": "SHA256:uALbfMqe7g4MMaRS5NMJen38dAEHwtxzR0iX0Ymuc80", + "added_at": "2023-01-02T03:04:00+00:00", + } + validate_openssh_public_key_mock.assert_awaited_once_with(self.PUBLIC_KEY) + + @pytest.mark.usefixtures("validate_openssh_public_key_mock") + async def test_adds_key_with_custom_name(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": self.PUBLIC_KEY, "name": "my-laptop"}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "my-laptop" + + @pytest.mark.usefixtures("validate_openssh_public_key_mock") + async def test_uses_md5_as_name_when_no_comment_and_no_name( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": self.PUBLIC_KEY_NO_COMMENT}, + ) + assert response.status_code == 200 + assert response.json()["name"] == "744e414c6ac55e3f15c1dd48229cbe74" + + @pytest.mark.parametrize( + "key", + [ + pytest.param("sha-rsa-invalid", id="only-one-field"), + pytest.param("ssh-rsa AAAAB3NzaC1kc3M=", id="dsa-declared-as-rsa"), + ], + ) + async def test_returns_400_for_invalid_key( + self, session: AsyncSession, client: AsyncClient, key: str + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": key}, + ) + assert response.status_code == 400 + assert "Invalid public key" in response.json()["detail"][0]["msg"] + + async def test_returns_400_for_unsupported_key( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": "ssh-dss AAAAB3NzaC1kc3M="}, + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["msg"] == "Unsupported key type: ssh-dss" + + @pytest.mark.usefixtures("validate_openssh_public_key_mock") + async def test_returns_400_resource_exists_for_duplicate_key( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + json={"key": self.PUBLIC_KEY}, + ) + assert response.status_code == 200 + response = await client.post( + "/api/users/public_keys/add", + headers=get_auth_headers(user.token), + # The same key, the comment does not matter + json={"key": self.PUBLIC_KEY_NO_COMMENT}, + ) + assert response.status_code == 400 + assert response.json()["detail"][0]["code"] == "resource_exists" + + +@pytest.mark.asyncio +@pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) +@pytest.mark.usefixtures("test_db") +class TestDeleteUserPublicKeys: + async def test_returns_40x_if_not_authenticated(self, client: AsyncClient): + response = await client.post( + "/api/users/public_keys/delete", + json={"ids": [str(uuid.uuid4())]}, + ) + assert response.status_code in [401, 403] + + async def test_deletes_public_key(self, session: AsyncSession, client: AsyncClient): + user = await create_user(session=session) + key = await create_user_public_key(session=session, user=user) + other_key = await create_user_public_key( + session=session, user=user, fingerprint="SHA256:other" + ) + response = await client.post( + "/api/users/public_keys/delete", + headers=get_auth_headers(user.token), + json={"ids": [str(key.id)]}, + ) + assert response.status_code == 200 + res = await session.execute( + select(UserPublicKeyModel).where(UserPublicKeyModel.user_id == user.id) + ) + assert res.scalars().all() == [other_key] + + async def test_silently_ignores_nonexistent_ids( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + response = await client.post( + "/api/users/public_keys/delete", + headers=get_auth_headers(user.token), + json={"ids": [str(uuid.uuid4())]}, + ) + assert response.status_code == 200 + + async def test_does_not_delete_other_users_keys( + self, session: AsyncSession, client: AsyncClient + ): + user = await create_user(session=session) + other_user = await create_user(session=session, name="other_user") + other_user_key = await create_user_public_key(session=session, user=other_user) + response = await client.post( + "/api/users/public_keys/delete", + headers=get_auth_headers(user.token), + json={"ids": [str(other_user_key.id)]}, + ) + assert response.status_code == 200 + res = await session.execute( + select(UserPublicKeyModel).where(UserPublicKeyModel.user_id == other_user.id) + ) + assert res.scalars().all() == [other_user_key] diff --git a/src/tests/_internal/server/routers/test_sshproxy.py b/src/tests/_internal/server/routers/test_sshproxy.py index 2b546d7d6..2b761b43f 100644 --- a/src/tests/_internal/server/routers/test_sshproxy.py +++ b/src/tests/_internal/server/routers/test_sshproxy.py @@ -19,6 +19,7 @@ create_repo, create_run, create_user, + create_user_public_key, get_auth_headers, get_job_provisioning_data, get_job_runtime_data, @@ -97,6 +98,12 @@ async def test_response( session=session, project=project, backend=BackendType.RUNPOD ) user = await create_user(session=session, ssh_public_key="user-key") + await create_user_public_key( + session=session, user=user, fingerprint="SHA256:fp1", key="user-uploaded-key-1" + ) + await create_user_public_key( + session=session, user=user, fingerprint="SHA256:fp2", key="user-uploaded-key-2" + ) repo = await create_repo(session=session, project_id=project.id) run_spec = get_run_spec(repo_id=repo.name, ssh_key_pub="run-spec-key") run = await create_run( @@ -138,6 +145,8 @@ async def test_response( "authorized_keys": unordered( [ "user-key", + "user-uploaded-key-1", + "user-uploaded-key-2", "run-spec-key", ] ),