Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 23 additions & 0 deletions src/dstack/_internal/core/models/runs.py
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,29 @@ class JobConnectionInfo(CoreModel):
)
),
]
sshproxy_hostname: Annotated[
Optional[str],
Field(description="sshproxy hostname. Not set if sshproxy is not configured."),
] = None
sshproxy_port: Annotated[
Optional[int],
Field(
description=(
"ssproxy port. Not set if sshproxy is not configured."
" May be not set if it is equal to the default SSH port 22."
)
),
] = None
sshproxy_upstream_id: Annotated[
Optional[str],
Field(
description=(
"sshproxy identifier for this job. SSH clients send this identifier as a username"
" to indicate which job they wish to connect."
" Not set if sshproxy is not configured."
)
),
] = None


class Job(CoreModel):
Expand Down
188 changes: 131 additions & 57 deletions src/dstack/_internal/core/services/ssh/attach.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,17 @@
_SSH_TUNNEL_REGEX = re.compile(r"(?:[\w.-]+:)?(?P<local_port>\d+):localhost:(?P<remote_port>\d+)")


class SSHAttach:
class BaseSSHAttach:
"""
A base class for SSH attach implementations.

Child classes must populate `self.hosts` inside overridden `__init__()` with at least one host
named as a `run_name` argument value.
"""

@classmethod
def get_control_sock_path(cls, run_name: str) -> Path:
return ConfigManager().dstack_ssh_dir / f"%r@{run_name}.control.sock"
return ConfigManager().dstack_ssh_dir / f"{run_name}.control.sock"

@classmethod
def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]:
Expand All @@ -57,21 +64,16 @@ def reuse_ports_lock(cls, run_name: str) -> Optional[PortsLock]:

def __init__(
self,
hostname: str,
ssh_port: int,
container_ssh_port: int,
user: str,
container_user: str,
id_rsa_path: PathLike,
ports_lock: PortsLock,
*,
run_name: str,
dockerized: bool,
ssh_proxy: Optional[SSHConnectionParams] = None,
identity_path: PathLike,
ports_lock: PortsLock,
destination: str,
service_port: Optional[int] = None,
local_backend: bool = False,
bind_address: Optional[str] = None,
):
self._attached = False
self._hosts_added_to_ssh_config = False
self._ports_lock = ports_lock
self.ports = ports_lock.dict()
self.run_name = run_name
Expand All @@ -80,9 +82,9 @@ def __init__(
# Cast all path-like values used in configs to FilePath instances for automatic
# path normalization in :func:`update_ssh_config`.
self.control_sock_path = FilePath(control_sock_path)
self.identity_file = FilePath(id_rsa_path)
self.identity_file = FilePath(identity_path)
self.tunnel = SSHTunnel(
destination=f"root@{run_name}",
destination=destination,
identity=self.identity_file,
forwarded_sockets=ports_to_forwarded_sockets(
ports=self.ports,
Expand All @@ -94,12 +96,92 @@ def __init__(
"ExitOnForwardFailure": "yes",
},
)
self.ssh_proxy = ssh_proxy
self.service_port = service_port
self.hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}

def __enter__(self):
self.attach()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.detach()

def attach(self):
include_ssh_config(self.ssh_config_path)
self._add_hosts_to_ssh_config()

hosts: dict[str, dict[str, Union[str, int, FilePath]]] = {}
self.hosts = hosts
self._ports_lock.release()

max_retries = 10
for i in range(max_retries):
try:
self.tunnel.open()
self._attached = True
atexit.register(self.detach)
return
except SSHError:
if i < max_retries - 1:
time.sleep(1)
self._remove_hosts_from_ssh_config()
raise SSHError("Can't connect to the remote host")

def detach(self):
self._remove_hosts_from_ssh_config()
if not self._attached:
logger.debug("Not attached")
return
self.tunnel.close()
self._attached = False
logger.debug("Detached")

def _add_hosts_to_ssh_config(self):
if self._hosts_added_to_ssh_config:
return
for host, options in self.hosts.items():
update_ssh_config(self.ssh_config_path, host, options)
self._hosts_added_to_ssh_config = True

def _remove_hosts_from_ssh_config(self):
if not self._hosts_added_to_ssh_config:
return
for host in self.hosts:
update_ssh_config(self.ssh_config_path, host, {})
self._hosts_added_to_ssh_config = False


class SSHAttach(BaseSSHAttach):
"""
`SSHAttach` attaches to a job directly, via a backend-specific chain of hosts.

Used when `dstack-sshproxy` is not configured on the server.
"""

def __init__(
self,
*,
run_name: str,
identity_path: PathLike,
ports_lock: PortsLock,
hostname: str,
ssh_port: int,
container_ssh_port: int,
user: str,
container_user: str,
dockerized: bool,
ssh_proxy: Optional[SSHConnectionParams] = None,
local_backend: bool = False,
service_port: Optional[int] = None,
bind_address: Optional[str] = None,
):
super().__init__(
run_name=run_name,
identity_path=identity_path,
ports_lock=ports_lock,
destination=f"root@{run_name}",
service_port=service_port,
bind_address=bind_address,
)
hosts = self.hosts
if local_backend:
hosts[run_name] = {
"HostName": hostname,
Expand Down Expand Up @@ -195,47 +277,39 @@ def __init__(
"StrictHostKeyChecking": "no",
"UserKnownHostsFile": "/dev/null",
}
if get_ssh_client_info().supports_multiplexing:
hosts[run_name].update(
{
"ControlMaster": "auto",
"ControlPath": self.control_sock_path,
}
)

def attach(self):
include_ssh_config(self.ssh_config_path)
for host, options in self.hosts.items():
update_ssh_config(self.ssh_config_path, host, options)

max_retries = 10
self._ports_lock.release()
for i in range(max_retries):
try:
self.tunnel.open()
self._attached = True
atexit.register(self.detach)
break
except SSHError:
if i < max_retries - 1:
time.sleep(1)
else:
self.detach()
raise SSHError("Can't connect to the remote host")
class SSHProxyAttach(BaseSSHAttach):
"""
`SSHProxyAttach` attaches to a job via `dstack-sshproxy`.

def detach(self):
if not self._attached:
logger.debug("Not attached")
return
self.tunnel.close()
for host in self.hosts:
update_ssh_config(self.ssh_config_path, host, {})
self._attached = False
logger.debug("Detached")
Used when `dstack-sshproxy` is configured on the server.
"""

def __enter__(self):
self.attach()
return self

def __exit__(self, exc_type, exc_val, exc_tb):
self.detach()
def __init__(
self,
*,
run_name: str,
identity_path: PathLike,
ports_lock: PortsLock,
hostname: str,
upstream_id: str,
port: Optional[int] = None,
service_port: Optional[int] = None,
bind_address: Optional[str] = None,
):
super().__init__(
run_name=run_name,
identity_path=identity_path,
ports_lock=ports_lock,
destination=f"{upstream_id}_root@{run_name}",
service_port=service_port,
bind_address=bind_address,
)
self.hosts[run_name] = {
"HostName": hostname,
"Port": port or 22,
"User": upstream_id,
"IdentityFile": self.identity_file,
"IdentitiesOnly": "yes",
}
32 changes: 22 additions & 10 deletions src/dstack/_internal/core/services/ssh/tunnel.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import subprocess
import tempfile
from dataclasses import dataclass
from typing import Dict, Iterable, List, Literal, Optional, Union
from typing import Dict, Iterable, List, Literal, NoReturn, Optional, Union

from dstack._internal.core.errors import SSHError
from dstack._internal.core.models.instances import SSHConnectionParams
Expand Down Expand Up @@ -181,9 +181,8 @@ def open(self) -> None:
raise SSHError(msg) from e
if r.returncode == 0:
return
stderr = self._read_log_file()
logger.debug("SSH tunnel failed: %s", stderr)
raise get_ssh_error(stderr)
log_output = self._read_log_file()
self._raise_ssh_error_from_log_output(log_output)

async def aopen(self) -> None:
await run_async(self._remove_log_file)
Expand All @@ -199,9 +198,8 @@ async def aopen(self) -> None:
raise SSHError(msg) from e
if proc.returncode == 0:
return
stderr = await run_async(self._read_log_file)
logger.debug("SSH tunnel failed: %s", stderr)
raise get_ssh_error(stderr)
log_output = await run_async(self._read_log_file)
self._raise_ssh_error_from_log_output(log_output)

def close(self) -> None:
if not os.path.exists(self.control_sock_path):
Expand Down Expand Up @@ -299,9 +297,13 @@ def _build_proxy_command(
]
return "ProxyCommand=" + shlex.join(command)

def _read_log_file(self) -> bytes:
with open(self.log_path, "rb") as f:
return f.read()
def _read_log_file(self) -> Optional[bytes]:
try:
with open(self.log_path, "rb") as f:
return f.read()
except OSError as e:
logger.debug("Failed to read SSH tunnel log file %s: %s", self.log_path, e)
return None

def _remove_log_file(self) -> None:
try:
Expand All @@ -311,6 +313,16 @@ def _remove_log_file(self) -> None:
except OSError as e:
logger.debug("Failed to remove SSH tunnel log file %s: %s", self.log_path, e)

def _raise_ssh_error_from_log_output(self, output: Optional[bytes]) -> NoReturn:
if output is None:
msg = "(no log file)"
ssh_error = SSHError()
else:
msg = output
ssh_error = get_ssh_error(output)
logger.debug("SSH tunnel failed: %s", msg)
raise ssh_error

def _get_identity_path(self, identity: FilePathOrContent, tmp_filename: str) -> PathLike:
if isinstance(identity, FilePath):
return identity.path
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
RunStatus,
)
from dstack._internal.core.models.volumes import InstanceMountPoint, Volume, VolumeMountPoint
from dstack._internal.server import settings as server_settings
from dstack._internal.server.background.pipeline_tasks.base import (
Fetcher,
Heartbeater,
Expand Down Expand Up @@ -525,12 +526,16 @@ async def _process_provisioning_status(
fmt(context.job_model),
context.job_submission.age,
)
ssh_user = job_provisioning_data.username
assert context.run.run_spec.ssh_key_pub is not None
user_ssh_key = context.run.run_spec.ssh_key_pub.strip()
public_keys = [context.project.ssh_public_key.strip(), user_ssh_key]
public_keys = [context.project.ssh_public_key.strip()]
ssh_user: Optional[str] = None
user_ssh_key: Optional[str] = None
if not server_settings.SSHPROXY_ENFORCED:
ssh_user = job_provisioning_data.username
assert context.run.run_spec.ssh_key_pub is not None
user_ssh_key = context.run.run_spec.ssh_key_pub.strip()
public_keys.append(user_ssh_key)
if job_provisioning_data.backend == BackendType.LOCAL:
user_ssh_key = ""
user_ssh_key = None
success = await run_async(
_process_provisioning_with_shim,
server_ssh_private_keys,
Expand Down Expand Up @@ -1065,8 +1070,8 @@ def _process_provisioning_with_shim(
volumes: list[Volume],
registry_auth: Optional[RegistryAuth],
public_keys: list[str],
ssh_user: str,
ssh_key: str,
ssh_user: Optional[str],
ssh_key: Optional[str],
) -> bool:
job_spec = JobSpec.__response__.parse_raw(job_model.job_spec_data)
shim_client = client.ShimClient(port=ports[DSTACK_SHIM_HTTP_PORT])
Expand Down Expand Up @@ -1128,7 +1133,7 @@ def _process_provisioning_with_shim(
volume_mounts=volume_mounts,
instance_mounts=instance_mounts,
gpu_devices=gpu_devices,
host_ssh_user=ssh_user,
host_ssh_user=ssh_user or "",
host_ssh_keys=[ssh_key] if ssh_key else [],
container_ssh_keys=public_keys,
instance_id=jpd.instance_id,
Expand All @@ -1143,8 +1148,8 @@ def _process_provisioning_with_shim(
container_user=container_user,
shm_size=job_spec.requirements.resources.shm_size,
public_keys=public_keys,
ssh_user=ssh_user,
ssh_key=ssh_key,
ssh_user=ssh_user or "",
ssh_key=ssh_key or "",
mounts=volume_mounts,
volumes=volumes,
instance_mounts=instance_mounts,
Expand Down
Loading
Loading