Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 86 additions & 8 deletions src/dstack/_internal/core/backends/azure/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
from typing import Dict, List, Optional, Tuple

from azure.core.credentials import TokenCredential
from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError
from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError
from azure.mgmt import compute as compute_mgmt
from azure.mgmt import network as network_mgmt
from azure.mgmt.compute.models import (
Expand Down Expand Up @@ -168,7 +168,7 @@ def create_instance(

# TODO: Support custom availability_zones.
# Currently, VMs are regional, which means they don't have zone info.
vm = _launch_instance(
vm = _create_instance_and_wait(
compute_client=self._compute_client,
subscription_id=self.config.subscription_id,
location=location,
Expand Down Expand Up @@ -272,7 +272,7 @@ def create_gateway(
)
tags = azure_resources.filter_invalid_tags(tags)

vm = _launch_instance(
vm = _create_instance_and_wait(
compute_client=self._compute_client,
subscription_id=self.config.subscription_id,
location=configuration.region,
Expand Down Expand Up @@ -426,8 +426,10 @@ def get_image_name(self) -> str:
r"ND(\d+)rs_v2", # NDv2-series [8xV100 32GB]
r"NV(\d+)adm?s_A10_v5", # NVadsA10 v5-series [A10]
r"NC(\d+)ads_A100_v4", # NC A100 v4-series [A100 80GB]
r"NC(\d+)adi?s_H100_v5", # NC H100 v5-series [H100 NVL 94GB]
r"ND(\d+)asr_v4", # ND A100 v4-series [8xA100 40GB]
r"ND(\d+)amsr_A100_v4", # NDm A100 v4-series [8xA100 80GB]
r"ND(\d+)isr_H200_v5", # ND H200 v5-series [8xH200 141GB]
]
_SUPPORTED_VM_SERIES_PATTERN = (
"^Standard_(" + "|".join(f"({s})" for s in _SUPPORTED_VM_SERIES_PATTERNS) + ")$"
Expand Down Expand Up @@ -508,7 +510,7 @@ def _get_gateway_image_ref() -> ImageReference:
)


def _launch_instance(
def _begin_create_instance(
compute_client: compute_mgmt.ComputeManagementClient,
subscription_id: str,
location: str,
Expand All @@ -529,7 +531,8 @@ def _launch_instance(
allocate_public_ip: bool = True,
network_resource_group: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> VirtualMachine:
):
"""Starts VM creation and returns immediately. The VM is created asynchronously."""
if tags is None:
tags = {}
if network_resource_group is None:
Expand Down Expand Up @@ -628,18 +631,93 @@ def _launch_instance(
message = e.error.message if e.error.message is not None else ""
raise NoCapacityError(message)
raise e
vm = poller.result(timeout=600)
return poller


def _create_instance_and_wait(
compute_client: compute_mgmt.ComputeManagementClient,
subscription_id: str,
location: str,
resource_group: str,
network_security_group: str,
network: str,
subnet: str,
managed_identity_name: Optional[str],
managed_identity_resource_group: Optional[str],
image_reference: ImageReference,
vm_size: str,
instance_name: str,
user_data: str,
ssh_pub_keys: List[str],
spot: bool,
disk_size: int,
computer_name: str,
allocate_public_ip: bool = True,
network_resource_group: Optional[str] = None,
tags: Optional[Dict[str, str]] = None,
) -> VirtualMachine:
"""Blocking version used for gateway provisioning where IP is needed immediately."""
poller = _begin_create_instance(
compute_client=compute_client,
subscription_id=subscription_id,
location=location,
resource_group=resource_group,
network_security_group=network_security_group,
network=network,
subnet=subnet,
managed_identity_name=managed_identity_name,
managed_identity_resource_group=managed_identity_resource_group,
image_reference=image_reference,
vm_size=vm_size,
instance_name=instance_name,
user_data=user_data,
ssh_pub_keys=ssh_pub_keys,
spot=spot,
disk_size=disk_size,
computer_name=computer_name,
allocate_public_ip=allocate_public_ip,
network_resource_group=network_resource_group,
tags=tags,
)
try:
vm = poller.result(timeout=600)
except HttpResponseError as e:
# Azure may create a VM resource even when provisioning fails (e.g., AllocationFailed).
# Clean it up to avoid orphan VMs.
logger.warning(
"Instance %s provisioning failed: %s. Cleaning up.",
instance_name,
repr(e),
)
_terminate_instance(
compute_client=compute_client,
resource_group=resource_group,
instance_name=instance_name,
)
if e.error is not None and e.error.code in (
"AllocationFailed",
"OverconstrainedAllocationRequest",
):
raise NoCapacityError(e.error.message or str(e))
raise
if not poller.done():
logger.error(
"Timed out waiting for instance {instance_name} launch. "
"The instance will be terminated."
"Timed out waiting for instance %s launch. The instance will be terminated.",
instance_name,
)
_terminate_instance(
compute_client=compute_client,
resource_group=resource_group,
instance_name=instance_name,
)
raise ComputeError(f"Timed out waiting for instance {instance_name} launch")
if (vm.provisioning_state or "").lower() == "failed":
_terminate_instance(
compute_client=compute_client,
resource_group=resource_group,
instance_name=instance_name,
)
raise NoCapacityError(f"VM {instance_name} provisioning failed")
return vm


Expand Down
Loading