From 7b37afb3ee8542d2069c6917ff61d68ca582bb08 Mon Sep 17 00:00:00 2001 From: Andrey Cheptsov Date: Thu, 26 Mar 2026 10:58:10 +0100 Subject: [PATCH] [Azure] Add support for H100 NVL and H200 VM series; refactor instance creation methods to cleanup failed instances --- .../_internal/core/backends/azure/compute.py | 94 +++++++++++++++++-- 1 file changed, 86 insertions(+), 8 deletions(-) diff --git a/src/dstack/_internal/core/backends/azure/compute.py b/src/dstack/_internal/core/backends/azure/compute.py index 74e585d631..deefdb0b75 100644 --- a/src/dstack/_internal/core/backends/azure/compute.py +++ b/src/dstack/_internal/core/backends/azure/compute.py @@ -6,7 +6,7 @@ from typing import Dict, List, Optional, Tuple from azure.core.credentials import TokenCredential -from azure.core.exceptions import ResourceExistsError, ResourceNotFoundError +from azure.core.exceptions import HttpResponseError, ResourceExistsError, ResourceNotFoundError from azure.mgmt import compute as compute_mgmt from azure.mgmt import network as network_mgmt from azure.mgmt.compute.models import ( @@ -168,7 +168,7 @@ def create_instance( # TODO: Support custom availability_zones. # Currently, VMs are regional, which means they don't have zone info. - vm = _launch_instance( + vm = _create_instance_and_wait( compute_client=self._compute_client, subscription_id=self.config.subscription_id, location=location, @@ -272,7 +272,7 @@ def create_gateway( ) tags = azure_resources.filter_invalid_tags(tags) - vm = _launch_instance( + vm = _create_instance_and_wait( compute_client=self._compute_client, subscription_id=self.config.subscription_id, location=configuration.region, @@ -426,8 +426,10 @@ def get_image_name(self) -> str: r"ND(\d+)rs_v2", # NDv2-series [8xV100 32GB] r"NV(\d+)adm?s_A10_v5", # NVadsA10 v5-series [A10] r"NC(\d+)ads_A100_v4", # NC A100 v4-series [A100 80GB] + r"NC(\d+)adi?s_H100_v5", # NC H100 v5-series [H100 NVL 94GB] r"ND(\d+)asr_v4", # ND A100 v4-series [8xA100 40GB] r"ND(\d+)amsr_A100_v4", # NDm A100 v4-series [8xA100 80GB] + r"ND(\d+)isr_H200_v5", # ND H200 v5-series [8xH200 141GB] ] _SUPPORTED_VM_SERIES_PATTERN = ( "^Standard_(" + "|".join(f"({s})" for s in _SUPPORTED_VM_SERIES_PATTERNS) + ")$" @@ -508,7 +510,7 @@ def _get_gateway_image_ref() -> ImageReference: ) -def _launch_instance( +def _begin_create_instance( compute_client: compute_mgmt.ComputeManagementClient, subscription_id: str, location: str, @@ -529,7 +531,8 @@ def _launch_instance( allocate_public_ip: bool = True, network_resource_group: Optional[str] = None, tags: Optional[Dict[str, str]] = None, -) -> VirtualMachine: +): + """Starts VM creation and returns immediately. The VM is created asynchronously.""" if tags is None: tags = {} if network_resource_group is None: @@ -628,11 +631,79 @@ def _launch_instance( message = e.error.message if e.error.message is not None else "" raise NoCapacityError(message) raise e - vm = poller.result(timeout=600) + return poller + + +def _create_instance_and_wait( + compute_client: compute_mgmt.ComputeManagementClient, + subscription_id: str, + location: str, + resource_group: str, + network_security_group: str, + network: str, + subnet: str, + managed_identity_name: Optional[str], + managed_identity_resource_group: Optional[str], + image_reference: ImageReference, + vm_size: str, + instance_name: str, + user_data: str, + ssh_pub_keys: List[str], + spot: bool, + disk_size: int, + computer_name: str, + allocate_public_ip: bool = True, + network_resource_group: Optional[str] = None, + tags: Optional[Dict[str, str]] = None, +) -> VirtualMachine: + """Blocking version used for gateway provisioning where IP is needed immediately.""" + poller = _begin_create_instance( + compute_client=compute_client, + subscription_id=subscription_id, + location=location, + resource_group=resource_group, + network_security_group=network_security_group, + network=network, + subnet=subnet, + managed_identity_name=managed_identity_name, + managed_identity_resource_group=managed_identity_resource_group, + image_reference=image_reference, + vm_size=vm_size, + instance_name=instance_name, + user_data=user_data, + ssh_pub_keys=ssh_pub_keys, + spot=spot, + disk_size=disk_size, + computer_name=computer_name, + allocate_public_ip=allocate_public_ip, + network_resource_group=network_resource_group, + tags=tags, + ) + try: + vm = poller.result(timeout=600) + except HttpResponseError as e: + # Azure may create a VM resource even when provisioning fails (e.g., AllocationFailed). + # Clean it up to avoid orphan VMs. + logger.warning( + "Instance %s provisioning failed: %s. Cleaning up.", + instance_name, + repr(e), + ) + _terminate_instance( + compute_client=compute_client, + resource_group=resource_group, + instance_name=instance_name, + ) + if e.error is not None and e.error.code in ( + "AllocationFailed", + "OverconstrainedAllocationRequest", + ): + raise NoCapacityError(e.error.message or str(e)) + raise if not poller.done(): logger.error( - "Timed out waiting for instance {instance_name} launch. " - "The instance will be terminated." + "Timed out waiting for instance %s launch. The instance will be terminated.", + instance_name, ) _terminate_instance( compute_client=compute_client, @@ -640,6 +711,13 @@ def _launch_instance( instance_name=instance_name, ) raise ComputeError(f"Timed out waiting for instance {instance_name} launch") + if (vm.provisioning_state or "").lower() == "failed": + _terminate_instance( + compute_client=compute_client, + resource_group=resource_group, + instance_name=instance_name, + ) + raise NoCapacityError(f"VM {instance_name} provisioning failed") return vm