diff --git a/src/dstack/_internal/core/models/runs.py b/src/dstack/_internal/core/models/runs.py index 0c48664a1..56f27b24e 100644 --- a/src/dstack/_internal/core/models/runs.py +++ b/src/dstack/_internal/core/models/runs.py @@ -472,6 +472,29 @@ class JobConnectionInfo(CoreModel): ) ), ] + sshproxy_hostname: Annotated[ + Optional[str], + Field(description="sshproxy hostname. Not set if sshproxy is not configured."), + ] = None + sshproxy_port: Annotated[ + Optional[int], + Field( + description=( + "sshproxy port. Not set if sshproxy is not configured." + " May be unset if it is equal to the default SSH port 22." + ) + ), + ] = None + sshproxy_upstream_id: Annotated[ + Optional[str], + Field( + description=( + "sshproxy identifier for this job. SSH clients send this identifier as a username" + " to indicate which job they wish to connect to." + " Not set if sshproxy is not configured." + ) + ), + ] = None class Job(CoreModel): diff --git a/src/dstack/_internal/core/services/ssh/attach.py b/src/dstack/_internal/core/services/ssh/attach.py index 91842665b..2400113f2 100644 --- a/src/dstack/_internal/core/services/ssh/attach.py +++ b/src/dstack/_internal/core/services/ssh/attach.py @@ -28,10 +28,17 @@ _SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P\d+):localhost:(?P\d+)") -class SSHAttach: +class BaseSSHAttach: + """ + A base class for SSH attach implementations. + + Child classes must populate `self.hosts` inside overridden `__init__()` with at least one host + named as a `run_name` argument value. 
+ """ + @classmethod def get_control_sock_path(cls, run_name: str) -> Path: - return ConfigManager().dstack_ssh_dir / f"%r@{run_name}.control.sock" + return ConfigManager().dstack_ssh_dir / f"{run_name}.control.sock" @classmethod def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]: @@ -57,21 +64,16 @@ def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]: def __init__( self, - hostname: str, - ssh_port: int, - container_ssh_port: int, - user: str, - container_user: str, - id_rsa_path: PathLike, - ports_lock: PortsLock, + *, run_name: str, - dockerized: bool, - ssh_proxy: Optional[SSHConnectionParams] = None, + identity_path: PathLike, + ports_lock: PortsLock, + destination: str, service_port: Optional[int] = None, - local_backend: bool = False, bind_address: Optional[str] = None, ): self._attached = False + self._hosts_added_to_ssh_config = False self._ports_lock = ports_lock self.ports = ports_lock.dict() self.run_name = run_name @@ -80,9 +82,9 @@ def __init__( # Cast all path-like values used in configs to FilePath instances for automatic # path normalization in :func:`update_ssh_config`. 
self.control_sock_path = FilePath(control_sock_path) - self.identity_file = FilePath(id_rsa_path) + self.identity_file = FilePath(identity_path) self.tunnel = SSHTunnel( - destination=f"root@{run_name}", + destination=destination, identity=self.identity_file, forwarded_sockets=ports_to_forwarded_sockets( ports=self.ports, @@ -94,12 +96,92 @@ def __init__( "ExitOnForwardFailure": "yes", }, ) - self.ssh_proxy = ssh_proxy self.service_port = service_port + self.hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {} + + def __enter__(self): + self.attach() + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + self.detach() + + def attach(self): + include_ssh_config(self.ssh_config_path) + self._add_hosts_to_ssh_config() - hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {} - self.hosts = hosts + self._ports_lock.release() + max_retries = 10 + for i in range(max_retries): + try: + self.tunnel.open() + self._attached = True + atexit.register(self.detach) + return + except SSHError: + if i < max_retries - 1: + time.sleep(1) + self._remove_hosts_from_ssh_config() + raise SSHError("Can't connect to the remote host") + + def detach(self): + self._remove_hosts_from_ssh_config() + if not self._attached: + logger.debug("Not attached") + return + self.tunnel.close() + self._attached = False + logger.debug("Detached") + + def _add_hosts_to_ssh_config(self): + if self._hosts_added_to_ssh_config: + return + for host, options in self.hosts.items(): + update_ssh_config(self.ssh_config_path, host, options) + self._hosts_added_to_ssh_config = True + + def _remove_hosts_from_ssh_config(self): + if not self._hosts_added_to_ssh_config: + return + for host in self.hosts: + update_ssh_config(self.ssh_config_path, host, {}) + self._hosts_added_to_ssh_config = False + + +class SSHAttach(BaseSSHAttach): + """ + `SSHAttach` attaches to a job directly, via a backend-specific chain of hosts. + + Used when `dstack-sshproxy` is not configured on the server. 
+ """ + + def __init__( + self, + *, + run_name: str, + identity_path: PathLike, + ports_lock: PortsLock, + hostname: str, + ssh_port: int, + container_ssh_port: int, + user: str, + container_user: str, + dockerized: bool, + ssh_proxy: Optional[SSHConnectionParams] = None, + local_backend: bool = False, + service_port: Optional[int] = None, + bind_address: Optional[str] = None, + ): + super().__init__( + run_name=run_name, + identity_path=identity_path, + ports_lock=ports_lock, + destination=f"root@{run_name}", + service_port=service_port, + bind_address=bind_address, + ) + hosts = self.hosts if local_backend: hosts[run_name] = { "HostName": hostname, @@ -195,47 +277,39 @@ def __init__( "StrictHostKeyChecking": "no", "UserKnownHostsFile": "/dev/null", } - if get_ssh_client_info().supports_multiplexing: - hosts[run_name].update( - { - "ControlMaster": "auto", - "ControlPath": self.control_sock_path, - } - ) - def attach(self): - include_ssh_config(self.ssh_config_path) - for host, options in self.hosts.items(): - update_ssh_config(self.ssh_config_path, host, options) - max_retries = 10 - self._ports_lock.release() - for i in range(max_retries): - try: - self.tunnel.open() - self._attached = True - atexit.register(self.detach) - break - except SSHError: - if i < max_retries - 1: - time.sleep(1) - else: - self.detach() - raise SSHError("Can't connect to the remote host") +class SSHProxyAttach(BaseSSHAttach): + """ + `SSHProxyAttach` attaches to a job via `dstack-sshproxy`. - def detach(self): - if not self._attached: - logger.debug("Not attached") - return - self.tunnel.close() - for host in self.hosts: - update_ssh_config(self.ssh_config_path, host, {}) - self._attached = False - logger.debug("Detached") + Used when `dstack-sshproxy` is configured on the server. 
+ """ - def __enter__(self): - self.attach() - return self - - def __exit__(self, exc_type, exc_val, exc_tb): - self.detach() + def __init__( + self, + *, + run_name: str, + identity_path: PathLike, + ports_lock: PortsLock, + hostname: str, + upstream_id: str, + port: Optional[int] = None, + service_port: Optional[int] = None, + bind_address: Optional[str] = None, + ): + super().__init__( + run_name=run_name, + identity_path=identity_path, + ports_lock=ports_lock, + destination=f"{upstream_id}_root@{run_name}", + service_port=service_port, + bind_address=bind_address, + ) + self.hosts[run_name] = { + "HostName": hostname, + "Port": port or 22, + "User": upstream_id, + "IdentityFile": self.identity_file, + "IdentitiesOnly": "yes", + } diff --git a/src/dstack/_internal/core/services/ssh/tunnel.py b/src/dstack/_internal/core/services/ssh/tunnel.py index e4f7f276e..b107032ab 100644 --- a/src/dstack/_internal/core/services/ssh/tunnel.py +++ b/src/dstack/_internal/core/services/ssh/tunnel.py @@ -5,7 +5,7 @@ import subprocess import tempfile from dataclasses import dataclass -from typing import Dict, Iterable, List, Literal, Optional, Union +from typing import Dict, Iterable, List, Literal, NoReturn, Optional, Union from dstack._internal.core.errors import SSHError from dstack._internal.core.models.instances import SSHConnectionParams @@ -181,9 +181,8 @@ def open(self) -> None: raise SSHError(msg) from e if r.returncode == 0: return - stderr = self._read_log_file() - logger.debug("SSH tunnel failed: %s", stderr) - raise get_ssh_error(stderr) + log_output = self._read_log_file() + self._raise_ssh_error_from_log_output(log_output) async def aopen(self) -> None: await run_async(self._remove_log_file) @@ -199,9 +198,8 @@ async def aopen(self) -> None: raise SSHError(msg) from e if proc.returncode == 0: return - stderr = await run_async(self._read_log_file) - logger.debug("SSH tunnel failed: %s", stderr) - raise get_ssh_error(stderr) + log_output = await 
run_async(self._read_log_file) + self._raise_ssh_error_from_log_output(log_output) def close(self) -> None: if not os.path.exists(self.control_sock_path): @@ -299,9 +297,13 @@ def _build_proxy_command( ] return "ProxyCommand=" + shlex.join(command) - def _read_log_file(self) -> bytes: - with open(self.log_path, "rb") as f: - return f.read() + def _read_log_file(self) -> Optional[bytes]: + try: + with open(self.log_path, "rb") as f: + return f.read() + except OSError as e: + logger.debug("Failed to read SSH tunnel log file %s: %s", self.log_path, e) + return None def _remove_log_file(self) -> None: try: @@ -311,6 +313,16 @@ def _remove_log_file(self) -> None: except OSError as e: logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e) + def _raise_ssh_error_from_log_output(self, output: Optional[bytes]) -> NoReturn: + if output is None: + msg = "(no log file)" + ssh_error = SSHError() + else: + msg = output + ssh_error = get_ssh_error(output) + logger.debug("SSH tunnel failed: %s", msg) + raise ssh_error + def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike: if isinstance(identity, FilePath): return identity.path diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py index 73af7d302..afcd0df41 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_running.py @@ -34,6 +34,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server import settings as server_settings from dstack._internal.server.background.pipeline_tasks.base import ( Fetcher, Heartbeater, @@ -525,12 +526,16 @@ async def _process_provisioning_status( fmt(context.job_model), context.job_submission.age, ) - ssh_user = job_provisioning_data.username - assert 
context.run.run_spec.ssh_key_pub is not None - user_ssh_key = context.run.run_spec.ssh_key_pub.strip() - public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] + public_keys = [context.project.ssh_public_key.strip()] + ssh_user: Optional[str] = None + user_ssh_key: Optional[str] = None + if not server_settings.SSHPROXY_ENFORCED: + ssh_user = job_provisioning_data.username + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys.append(user_ssh_key) if job_provisioning_data.backend == BackendType.LOCAL: - user_ssh_key = "" + user_ssh_key = None success = await run_async( _process_provisioning_with_shim, server_ssh_private_keys, @@ -1065,8 +1070,8 @@ def _process_provisioning_with_shim( volumes: list[Volume], registry_auth: Optional[RegistryAuth], public_keys: list[str], - ssh_user: str, - ssh_key: str, + ssh_user: Optional[str], + ssh_key: Optional[str], ) -> bool: job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data) shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT]) @@ -1128,7 +1133,7 @@ def _process_provisioning_with_shim( volume_mounts=volume_mounts, instance_mounts=instance_mounts, gpu_devices=gpu_devices, - host_ssh_user=ssh_user, + host_ssh_user=ssh_user or "", host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, instance_id=jpd.instance_id, @@ -1143,8 +1148,8 @@ def _process_provisioning_with_shim( container_user=container_user, shm_size=job_spec.requirements.resources.shm_size, public_keys=public_keys, - ssh_user=ssh_user, - ssh_key=ssh_key, + ssh_user=ssh_user or "", + ssh_key=ssh_key or "", mounts=volume_mounts, volumes=volumes, instance_mounts=instance_mounts, diff --git a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py index 72d60b581..5f63ac8b5 100644 --- 
a/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py +++ b/src/dstack/_internal/server/background/scheduled_tasks/running_jobs.py @@ -37,6 +37,7 @@ RunStatus, ) from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint +from dstack._internal.server import settings as server_settings from dstack._internal.server.background.scheduled_tasks.common import get_provisioning_timeout from dstack._internal.server.db import get_db, get_session_ctx from dstack._internal.server.models import ( @@ -382,13 +383,17 @@ async def _process_running_job_provisioning_state( fmt(context.job_model), context.job_submission.age, ) - ssh_user = job_provisioning_data.username - assert context.run.run_spec.ssh_key_pub is not None - user_ssh_key = context.run.run_spec.ssh_key_pub.strip() - public_keys = [context.project.ssh_public_key.strip(), user_ssh_key] + public_keys = [context.project.ssh_public_key.strip()] + ssh_user: Optional[str] = None + user_ssh_key: Optional[str] = None + if not server_settings.SSHPROXY_ENFORCED: + ssh_user = job_provisioning_data.username + assert context.run.run_spec.ssh_key_pub is not None + user_ssh_key = context.run.run_spec.ssh_key_pub.strip() + public_keys.append(user_ssh_key) if job_provisioning_data.backend == BackendType.LOCAL: # No need to update ~/.ssh/authorized_keys when running shim locally - user_ssh_key = "" + user_ssh_key = None success = await common_utils.run_async( _process_provisioning_with_shim, server_ssh_private_keys, @@ -725,8 +730,8 @@ def _process_provisioning_with_shim( volumes: List[Volume], registry_auth: Optional[RegistryAuth], public_keys: List[str], - ssh_user: str, - ssh_key: str, + ssh_user: Optional[str], + ssh_key: Optional[str], ) -> bool: """ Possible next states: @@ -804,7 +809,7 @@ def _process_provisioning_with_shim( volume_mounts=volume_mounts, instance_mounts=instance_mounts, gpu_devices=gpu_devices, - host_ssh_user=ssh_user, + host_ssh_user=ssh_user or "", 
host_ssh_keys=[ssh_key] if ssh_key else [], container_ssh_keys=public_keys, instance_id=jpd.instance_id, @@ -819,8 +824,8 @@ def _process_provisioning_with_shim( container_user=container_user, shm_size=job_spec.requirements.resources.shm_size, public_keys=public_keys, - ssh_user=ssh_user, - ssh_key=ssh_key, + ssh_user=ssh_user or "", + ssh_key=ssh_key or "", mounts=volume_mounts, volumes=volumes, instance_mounts=instance_mounts, diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index 7bdf49c6d..e37c3a871 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -29,6 +29,7 @@ RunSpec, ) from dstack._internal.core.models.volumes import Volume, VolumeMountPoint, VolumeStatus +from dstack._internal.server import settings from dstack._internal.server.models import ( InstanceModel, JobModel, @@ -56,6 +57,7 @@ from dstack._internal.server.services.sshproxy import ( build_proxied_job_ssh_command, build_proxied_job_ssh_url_authority, + build_proxied_job_upstream_id, ) from dstack._internal.utils import common from dstack._internal.utils.common import run_async @@ -526,12 +528,23 @@ def get_job_connection_info(job_model: JobModel, run_spec: RunSpec) -> JobConnec if proxied_url_authority is not None: proxied_ide_url = ide.get_url(proxied_url_authority, jrd.working_dir) + sshproxy_hostname: Optional[str] = None + sshproxy_port: Optional[int] = None + sshproxy_upstream_id: Optional[str] = None + if settings.SSHPROXY_ENABLED: + sshproxy_hostname = settings.SSHPROXY_HOSTNAME + sshproxy_port = settings.SSHPROXY_PORT + sshproxy_upstream_id = build_proxied_job_upstream_id(job_model) + return JobConnectionInfo( ide_name=ide_name, attached_ide_url=attached_ide_url, proxied_ide_url=proxied_ide_url, attached_ssh_command=build_ssh_command(hostname=attached_hostname), proxied_ssh_command=build_proxied_job_ssh_command(job_model), + 
sshproxy_hostname=sshproxy_hostname, + sshproxy_port=sshproxy_port, + sshproxy_upstream_id=sshproxy_upstream_id, ) diff --git a/src/dstack/_internal/server/services/sshproxy/__init__.py b/src/dstack/_internal/server/services/sshproxy/__init__.py index 9966ba91f..aa4621a2c 100644 --- a/src/dstack/_internal/server/services/sshproxy/__init__.py +++ b/src/dstack/_internal/server/services/sshproxy/__init__.py @@ -10,7 +10,7 @@ def build_proxied_job_ssh_url_authority(job: JobModel) -> Optional[str]: return None assert settings.SSHPROXY_HOSTNAME is not None return build_ssh_url_authority( - username=_build_proxied_job_username(job), + username=build_proxied_job_upstream_id(job), hostname=settings.SSHPROXY_HOSTNAME, port=settings.SSHPROXY_PORT, ) @@ -21,12 +21,12 @@ def build_proxied_job_ssh_command(job: JobModel) -> Optional[list[str]]: return None assert settings.SSHPROXY_HOSTNAME is not None return build_ssh_command( - username=_build_proxied_job_username(job), + username=build_proxied_job_upstream_id(job), hostname=settings.SSHPROXY_HOSTNAME, port=settings.SSHPROXY_PORT, ) -def _build_proxied_job_username(job: JobModel) -> str: +def build_proxied_job_upstream_id(job: JobModel) -> str: # Job's UUID in lowercase, without dashes return job.id.hex diff --git a/src/dstack/_internal/server/settings.py b/src/dstack/_internal/server/settings.py index e005a6cb5..b177c2959 100644 --- a/src/dstack/_internal/server/settings.py +++ b/src/dstack/_internal/server/settings.py @@ -111,6 +111,10 @@ "DSTACK_SERVER_SSHPROXY_ADDRESS", parse_hostname_port, default=(None, None) ) SSHPROXY_ENABLED = SSHPROXY_API_TOKEN is not None and SSHPROXY_HOSTNAME is not None +SSHPROXY_ENFORCED = os.getenv("DSTACK_SERVER_SSHPROXY_ENFORCED") is not None +if SSHPROXY_ENFORCED and not SSHPROXY_ENABLED: + logger.warning("sshproxy is not enabled, ignoring DSTACK_SERVER_SSHPROXY_ENFORCED") + SSHPROXY_ENFORCED = False SERVER_KEEP_SHIM_TASKS = os.getenv("DSTACK_SERVER_KEEP_SHIM_TASKS") is not None diff --git 
a/src/dstack/api/_public/runs.py b/src/dstack/api/_public/runs.py index 675a88d29..a8afac24c 100644 --- a/src/dstack/api/_public/runs.py +++ b/src/dstack/api/_public/runs.py @@ -41,7 +41,7 @@ from dstack._internal.core.models.runs import Run as RunModel from dstack._internal.core.services.configs import ConfigManager from dstack._internal.core.services.logs import URLReplacer -from dstack._internal.core.services.ssh.attach import SSHAttach +from dstack._internal.core.services.ssh.attach import BaseSSHAttach, SSHAttach, SSHProxyAttach from dstack._internal.core.services.ssh.key_manager import UserSSHKeyManager from dstack._internal.core.services.ssh.ports import PortsLock from dstack._internal.server.schemas.logs import PollLogsRequest @@ -77,7 +77,7 @@ def __init__( self._project = project self._run = run self._ports_lock: Optional[PortsLock] = ports_lock - self._ssh_attach: Optional[SSHAttach] = None + self._ssh_attach: Optional[BaseSSHAttach] = None if ssh_identity_file is not None: logger.warning( "[code]ssh_identity_file[/code] in [code]Run[/code] is deprecated and ignored; will be removed" @@ -347,37 +347,62 @@ def attach( self._ports_lock.dict(), ) - container_ssh_port = DSTACK_RUNNER_SSH_PORT - runtime_data = latest_job_submission.job_runtime_data - if runtime_data is not None and runtime_data.ports is not None: - container_ssh_port = runtime_data.ports.get(container_ssh_port, container_ssh_port) - - if runtime_data is not None and runtime_data.username is not None: - container_user = runtime_data.username - elif job.job_spec.user is not None and job.job_spec.user.username is not None: - container_user = job.job_spec.user.username - else: - container_user = "root" - service_port = None if isinstance(self._run.run_spec.configuration, ServiceConfiguration): service_port = get_service_port(job.job_spec, self._run.run_spec.configuration) - self._ssh_attach = SSHAttach( - hostname=provisioning_data.hostname, - ssh_port=provisioning_data.ssh_port, - 
container_ssh_port=container_ssh_port, - user=provisioning_data.username, - container_user=container_user, - id_rsa_path=ssh_identity_file, - ports_lock=self._ports_lock, - run_name=name, - dockerized=provisioning_data.dockerized, - ssh_proxy=provisioning_data.ssh_proxy, - service_port=service_port, - local_backend=provisioning_data.backend == BackendType.LOCAL, - bind_address=bind_address, - ) + ssh_attach: BaseSSHAttach + + if (jci := job.job_connection_info) is not None and jci.sshproxy_hostname is not None: + assert jci.sshproxy_upstream_id is not None + ssh_attach = SSHProxyAttach( + hostname=jci.sshproxy_hostname, + port=jci.sshproxy_port, + upstream_id=jci.sshproxy_upstream_id, + identity_path=ssh_identity_file, + ports_lock=self._ports_lock, + run_name=name, + service_port=service_port, + bind_address=bind_address, + ) + else: + hostname = provisioning_data.hostname + assert hostname is not None + ssh_port = provisioning_data.ssh_port + assert ssh_port is not None + + runtime_data = latest_job_submission.job_runtime_data + + container_ssh_port = DSTACK_RUNNER_SSH_PORT + if runtime_data is not None and runtime_data.ports is not None: + container_ssh_port = runtime_data.ports.get( + container_ssh_port, container_ssh_port + ) + + if runtime_data is not None and runtime_data.username is not None: + container_user = runtime_data.username + elif job.job_spec.user is not None and job.job_spec.user.username is not None: + container_user = job.job_spec.user.username + else: + container_user = "root" + + ssh_attach = SSHAttach( + hostname=hostname, + ssh_port=ssh_port, + container_ssh_port=container_ssh_port, + user=provisioning_data.username, + container_user=container_user, + identity_path=ssh_identity_file, + ports_lock=self._ports_lock, + run_name=name, + dockerized=provisioning_data.dockerized, + ssh_proxy=provisioning_data.ssh_proxy, + service_port=service_port, + local_backend=provisioning_data.backend == BackendType.LOCAL, + bind_address=bind_address, + ) + + 
self._ssh_attach = ssh_attach if not ports_lock: self._ssh_attach.attach() self._ports_lock = None diff --git a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py index 5f113a142..e4469a66a 100644 --- a/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/pipeline_tasks/test_running_jobs.py @@ -545,9 +545,90 @@ async def test_runs_provisioning_job( assert job_runtime_data.working_dir == "/dstack/run" assert job_runtime_data.username == "dstack" + @pytest.mark.parametrize("sshproxy_enforced", [False, True]) + async def test_provisioning_shim( + self, + monkeypatch: pytest.MonkeyPatch, + test_db, + session: AsyncSession, + worker: JobRunningWorker, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + sshproxy_enforced: bool, + ): + monkeypatch.setattr( + "dstack._internal.server.settings.SSHPROXY_ENFORCED", sshproxy_enforced + ) + project_ssh_pub_key = "__project_ssh_pub_key__" + project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) + user = await create_user(session=session) + repo = await create_repo(session=session, project_id=project.id) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, project=project, status=InstanceStatus.BUSY + ) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as py_version: + py_version.return_value = "3.13" + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + submitted_at=get_current_datetime(), + job_provisioning_data=job_provisioning_data, + instance=instance, + 
instance_assigned=True, + ) + + await _process_job(session, worker, job) + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit_task.assert_called_once_with( + task_id=job.id, + name="test-run-0-0", + registry_username="", + registry_password="", + image_name=( + f"dstackai/base:{settings.DSTACK_BASE_IMAGE_VERSION}-" + f"base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}" + ), + container_user="root", + privileged=False, + gpu=None, + cpu=None, + memory=None, + shm_size=None, + network_mode=NetworkMode.HOST, + volumes=[], + volume_mounts=[], + instance_mounts=[], + gpu_devices=[], + host_ssh_user="" if sshproxy_enforced else "ubuntu", + host_ssh_keys=[] if sshproxy_enforced else ["user_ssh_key"], + container_ssh_keys=[project_ssh_pub_key] + if sshproxy_enforced + else [project_ssh_pub_key, "user_ssh_key"], + instance_id=job_provisioning_data.instance_id, + ) + await session.refresh(job) + assert job.status == JobStatus.PULLING + @pytest.mark.parametrize("privileged", [False, True]) async def test_provisioning_shim_with_volumes( self, + monkeypatch: pytest.MonkeyPatch, test_db, session: AsyncSession, worker: JobRunningWorker, @@ -555,6 +636,7 @@ async def test_provisioning_shim_with_volumes( shim_client_mock: Mock, privileged: bool, ): + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_ENFORCED", False) project_ssh_pub_key = "__project_ssh_pub_key__" project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) user = await create_user(session=session) diff --git a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py index 66b38f331..0e2a63162 100644 --- a/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py +++ b/src/tests/_internal/server/background/scheduled_tasks/test_running_jobs.py @@ -301,17 +301,101 @@ async def 
test_running_job_terminates_when_done_by_runner( assert job.exit_status == 0 assert job.runner_timestamp == 2 + @pytest.mark.asyncio + @pytest.mark.parametrize("sshproxy_enforced", [False, True]) + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_provisioning_shim( + self, + monkeypatch: pytest.MonkeyPatch, + test_db, + session: AsyncSession, + ssh_tunnel_mock: Mock, + shim_client_mock: Mock, + sshproxy_enforced: bool, + ): + monkeypatch.setattr( + "dstack._internal.server.settings.SSHPROXY_ENFORCED", sshproxy_enforced + ) + project_ssh_pub_key = "__project_ssh_pub_key__" + project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) + user = await create_user(session=session) + repo = await create_repo( + session=session, + project_id=project.id, + ) + run_spec = get_run_spec(run_name="test-run", repo_id=repo.name) + run = await create_run( + session=session, + project=project, + repo=repo, + user=user, + run_name="test-run", + run_spec=run_spec, + ) + instance = await create_instance( + session=session, + project=project, + status=InstanceStatus.BUSY, + ) + job_provisioning_data = get_job_provisioning_data(dockerized=True) + + with patch( + "dstack._internal.server.services.jobs.configurators.base.get_default_python_verison" + ) as PyVersion: + PyVersion.return_value = "3.13" + job = await create_job( + session=session, + run=run, + status=JobStatus.PROVISIONING, + job_provisioning_data=job_provisioning_data, + instance=instance, + instance_assigned=True, + ) + + await process_running_jobs() + + ssh_tunnel_mock.assert_called_once() + shim_client_mock.healthcheck.assert_called_once() + shim_client_mock.submit_task.assert_called_once_with( + task_id=job.id, + name="test-run-0-0", + registry_username="", + registry_password="", + image_name=f"dstackai/base:{settings.DSTACK_BASE_IMAGE_VERSION}-base-ubuntu{settings.DSTACK_BASE_IMAGE_UBUNTU_VERSION}", + container_user="root", + privileged=False, + 
gpu=None, + cpu=None, + memory=None, + shm_size=None, + network_mode=NetworkMode.HOST, + volumes=[], + volume_mounts=[], + instance_mounts=[], + gpu_devices=[], + host_ssh_user="" if sshproxy_enforced else "ubuntu", + host_ssh_keys=[] if sshproxy_enforced else ["user_ssh_key"], + container_ssh_keys=[project_ssh_pub_key] + if sshproxy_enforced + else [project_ssh_pub_key, "user_ssh_key"], + instance_id=job_provisioning_data.instance_id, + ) + await session.refresh(job) + assert job.status == JobStatus.PULLING + @pytest.mark.asyncio @pytest.mark.parametrize("privileged", [False, True]) @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) async def test_provisioning_shim_with_volumes( self, + monkeypatch: pytest.MonkeyPatch, test_db, session: AsyncSession, ssh_tunnel_mock: Mock, shim_client_mock: Mock, privileged: bool, ): + monkeypatch.setattr("dstack._internal.server.settings.SSHPROXY_ENFORCED", False) project_ssh_pub_key = "__project_ssh_pub_key__" project = await create_project(session=session, ssh_public_key=project_ssh_pub_key) user = await create_user(session=session) diff --git a/src/tests/_internal/server/routers/test_runs.py b/src/tests/_internal/server/routers/test_runs.py index 31b4813aa..a30c3e30e 100644 --- a/src/tests/_internal/server/routers/test_runs.py +++ b/src/tests/_internal/server/routers/test_runs.py @@ -1189,12 +1189,15 @@ async def test_returns_run_with_job_connection_info_dev_environment( "proxied_ssh_command": ["ssh", f"{job.id.hex}@example.com", "-p", "2222"] if sshproxy else None, + "sshproxy_hostname": "example.com" if sshproxy else None, + "sshproxy_port": 2222 if sshproxy else None, + "sshproxy_upstream_id": job.id.hex if sshproxy else None, } @pytest.mark.asyncio @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) - async def test_returns_run_with_job_connection_info_task( - self, monkeypatch: pytest.MonkeyPatch, test_db, session: AsyncSession, client: AsyncClient + async def 
test_returns_run_with_job_connection_info_multi_replica_multi_node_task( + self, test_db, session: AsyncSession, client: AsyncClient ): user = await create_user(session=session, global_role=GlobalRole.USER) project = await create_project(session=session, owner=user) @@ -1242,6 +1245,9 @@ async def test_returns_run_with_job_connection_info_task( "attached_ide_url": None, "proxied_ide_url": None, "proxied_ssh_command": None, + "sshproxy_hostname": None, + "sshproxy_port": None, + "sshproxy_upstream_id": None, } assert jobs[0]["job_connection_info"] == { "attached_ssh_command": ["ssh", "test-task"],