Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
337 changes: 337 additions & 0 deletions docker/patch/latest/sglang_delta_compression.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,337 @@
diff -urN tmp_sglang_orig/python/sglang/srt/managers/io_struct.py tmp_sglang_mod/python/sglang/srt/managers/io_struct.py
--- tmp_sglang_orig/python/sglang/srt/managers/io_struct.py 2026-04-07 01:29:33.170989783 +0000
+++ tmp_sglang_mod/python/sglang/srt/managers/io_struct.py 2026-04-07 01:30:22.328323959 +0000
@@ -1362,6 +1362,7 @@
names: List[str]
dtypes: List[str]
shapes: List[List[int]]
+ sparse_metadata: Optional[List[Dict[str, Any]]] = None
# The group name
group_name: str = "weight_update_group"
# Whether to flush the cache after updating weights
diff -urN tmp_sglang_orig/python/sglang/srt/managers/tp_worker.py tmp_sglang_mod/python/sglang/srt/managers/tp_worker.py
--- tmp_sglang_orig/python/sglang/srt/managers/tp_worker.py 2026-04-07 01:29:33.368983074 +0000
+++ tmp_sglang_mod/python/sglang/srt/managers/tp_worker.py 2026-04-07 01:30:32.446981062 +0000
@@ -148,6 +148,7 @@
recv_req.names,
recv_req.dtypes,
recv_req.shapes,
+ recv_req.sparse_metadata,
recv_req.group_name,
recv_req.load_format,
)
diff -urN tmp_sglang_orig/python/sglang/srt/model_executor/model_runner.py tmp_sglang_mod/python/sglang/srt/model_executor/model_runner.py
--- tmp_sglang_orig/python/sglang/srt/model_executor/model_runner.py 2026-04-07 01:29:33.592975483 +0000
+++ tmp_sglang_mod/python/sglang/srt/model_executor/model_runner.py 2026-04-07 01:31:33.304918738 +0000
@@ -13,6 +13,7 @@
# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

+import contextlib
import datetime
import gc
import inspect
@@ -247,6 +248,9 @@

logger = logging.getLogger(__name__)

+_ORIGINAL_TENSOR_COPY = torch.Tensor.copy_
+_ORIGINAL_TENSOR_FILL = torch.Tensor.fill_
+

def resolve_language_model(model: nn.Module) -> nn.Module:
model_cls_name = model.__class__.__name__
@@ -1341,6 +1345,7 @@
names,
dtypes,
shapes,
+ sparse_metadata,
group_name,
load_format: Optional[str] = None,
):
@@ -1363,6 +1368,18 @@
return self._update_bucketed_weights_from_distributed(
names, dtypes, shapes, group_name
)
+ if load_format == "distributed_delta_sparse_indices":
+ return self._apply_sparse_delta_weights_from_distributed(
+ dtypes, shapes, sparse_metadata, group_name, transport="sparse_indices"
+ )
+ if load_format == "distributed_delta_sparse_bitmask":
+ return self._apply_sparse_delta_weights_from_distributed(
+ dtypes, shapes, sparse_metadata, group_name, transport="sparse_bitmask"
+ )
+ if load_format == "distributed_delta":
+ return self._apply_delta_weights_from_distributed(
+ names, dtypes, shapes, group_name
+ )
try:
weights = []
handles = []
@@ -1395,6 +1412,151 @@
logger.error(error_msg)
return False, error_msg

+    def _apply_delta_weights_from_distributed(
+        self, names, dtypes, shapes, group_name
+    ):
+        """Receive dense weight deltas via NCCL broadcast and apply them additively.
+
+        For each (name, dtype, shape) triple an empty GPU buffer is allocated and
+        filled by an async broadcast from rank 0 of ``group_name``; all handles are
+        awaited before loading. ``load_weights`` then runs inside
+        ``_additive_weight_copy_context``, which rebinds ``Tensor.copy_``/``fill_``
+        to ``add_`` so the received tensors are treated as deltas (param += delta)
+        rather than replacements.
+
+        Args:
+            names: parameter names understood by ``self.model.load_weights``.
+            dtypes: per-tensor dtypes, either ``torch.dtype`` or dtype-name strings.
+            shapes: per-tensor shapes.
+            group_name: key into ``self._model_update_group`` selecting the group.
+
+        Returns:
+            (bool, str): success flag plus a human-readable message.
+        """
+        try:
+            weights = []
+            handles = []
+            for name, dtype, shape in zip(names, dtypes, shapes):
+                # Accept either a real torch.dtype or its string name (e.g. "bfloat16").
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                weight = torch.empty(shape, dtype=target_dtype, device=self.device)
+                # Issue all broadcasts asynchronously so receives overlap.
+                handles.append(
+                    torch.distributed.broadcast(
+                        weight,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                weights.append((name, weight))
+            for handle in handles:
+                handle.wait()
+            # Additive copy semantics; post_load_weights (if any) is wrapped so it
+            # still runs with the original, non-additive copy_/fill_.
+            with _additive_weight_copy_context():
+                with _wrap_post_load_weights_with_original_copy_context(self.model):
+                    self.model.load_weights(weights)
+            return True, "Succeeded to apply weight deltas online."
+
+        except Exception as e:
+            # A partial load leaves params half-updated; the caller must resend
+            # full weights rather than retry deltas.
+            error_msg = (
+                f"Failed to apply weight deltas online: {e}. "
+                f"The model weights may be in an inconsistent state. "
+                f"Please discard the whole weights."
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
+    def _apply_sparse_delta_weights_from_distributed(
+        self, dtypes, shapes, sparse_metadata, group_name, transport
+    ):
+        """Receive sparse-encoded weight deltas, decode them, and apply additively.
+
+        Two packed tensors are broadcast from rank 0 (``encoded_tensors[0]`` and
+        ``[1]``); ``sparse_metadata`` carries one dict per parameter describing
+        how to slice those tensors back into per-parameter deltas:
+          - transport == "sparse_indices": tensor 0 holds flat element indices,
+            tensor 1 the corresponding values; per-param keys are
+            name/dtype/shape/numel/index_start/index_end/value_start/value_end.
+          - transport == "sparse_bitmask": tensor 0 holds packed uint8 bitmasks
+            (unpacked via _unpack_bitmask), tensor 1 the values; per-param keys
+            use mask_start/mask_end instead of index_start/index_end.
+        Decoded deltas are fed to ``load_weights`` in size-capped batches under
+        ``_additive_weight_copy_context`` so they add onto existing params.
+        NOTE(review): ``index_copy_``/boolean assignment require the packed
+        values' dtype to match each param's ``meta["dtype"]`` — assumes the
+        sender encodes values per-target-dtype; confirm against the producer.
+
+        Returns:
+            (bool, str): success flag plus a human-readable message.
+        """
+        # 512 MiB batch cap for load_weights amortization. Each load_weights call
+        # costs ~2ms in name resolution + MoE expert remapping, so batching reduces
+        # call count across ~6000 params. Sweep at GLM-4.7-355B H100 64-rollout:
+        # 96 MiB → 1110 calls, 37.8s avg delta sync
+        # 256 MiB → ~400 calls, ~32s
+        # 512 MiB → 128 calls, 30.3s ← chosen
+        # 1024 MiB → OOM on some engines (not enough scratch for decode)
+        _BATCH_BYTE_CAP = 512 * 1024 * 1024
+        try:
+            encoded_tensors = []
+            handles = []
+            for dtype, shape in zip(dtypes, shapes):
+                # Accept either a real torch.dtype or its string name.
+                target_dtype = (
+                    dtype if isinstance(dtype, torch.dtype) else getattr(torch, dtype)
+                )
+                encoded_tensor = torch.empty(shape, dtype=target_dtype, device=self.device)
+                handles.append(
+                    torch.distributed.broadcast(
+                        encoded_tensor,
+                        src=0,
+                        group=self._model_update_group[group_name],
+                        async_op=True,
+                    )
+                )
+                encoded_tensors.append(encoded_tensor)
+            for handle in handles:
+                handle.wait()
+            with _additive_weight_copy_context():
+                with _wrap_post_load_weights_with_original_copy_context(self.model):
+                    batch = []
+                    batch_bytes = 0
+
+                    # Flush the accumulated (name, tensor) pairs through one
+                    # load_weights call; no-op on an empty batch.
+                    def _flush_batch():
+                        nonlocal batch, batch_bytes
+                        if not batch:
+                            return
+                        self.model.load_weights(batch)
+                        batch = []
+                        batch_bytes = 0
+
+                    if transport == "sparse_indices":
+                        # index_copy_ requires int64 indices.
+                        packed_indices = encoded_tensors[0].to(dtype=torch.long)
+                        packed_values = encoded_tensors[1]
+                        for meta in sparse_metadata:
+                            target_dtype = getattr(torch, meta["dtype"])
+                            target_shape = tuple(meta["shape"])
+                            numel = int(meta["numel"])
+                            # Zeros everywhere the delta is absent: adding the
+                            # decoded tensor leaves untouched elements unchanged.
+                            decoded_flat = torch.zeros(
+                                numel, dtype=target_dtype, device=self.device
+                            )
+                            index_start = int(meta["index_start"])
+                            index_end = int(meta["index_end"])
+                            value_start = int(meta["value_start"])
+                            value_end = int(meta["value_end"])
+                            # Empty slice means this param has no changed elements.
+                            if index_end > index_start:
+                                decoded_flat.index_copy_(
+                                    0,
+                                    packed_indices[index_start:index_end],
+                                    packed_values[value_start:value_end],
+                                )
+                            decoded_weight = decoded_flat.view(target_shape)
+                            tensor_bytes = numel * decoded_flat.element_size()
+                            # Flush BEFORE appending so a single oversized tensor
+                            # still goes through on its own.
+                            if batch_bytes + tensor_bytes > _BATCH_BYTE_CAP and batch:
+                                _flush_batch()
+                            batch.append((meta["name"], decoded_weight))
+                            batch_bytes += tensor_bytes
+                        _flush_batch()
+                    elif transport == "sparse_bitmask":
+                        packed_masks = encoded_tensors[0]
+                        packed_values = encoded_tensors[1]
+                        for meta in sparse_metadata:
+                            target_dtype = getattr(torch, meta["dtype"])
+                            target_shape = tuple(meta["shape"])
+                            numel = int(meta["numel"])
+                            decoded_flat = torch.zeros(
+                                numel, dtype=target_dtype, device=self.device
+                            )
+                            mask_start = int(meta["mask_start"])
+                            mask_end = int(meta["mask_end"])
+                            value_start = int(meta["value_start"])
+                            value_end = int(meta["value_end"])
+                            # Expand packed bytes into a bool mask of length numel;
+                            # values are scattered into the masked positions.
+                            unpacked_mask = _unpack_bitmask(
+                                packed_masks[mask_start:mask_end], numel, self.device
+                            )
+                            decoded_flat[unpacked_mask] = packed_values[value_start:value_end]
+                            decoded_weight = decoded_flat.view(target_shape)
+                            tensor_bytes = numel * decoded_flat.element_size()
+                            if batch_bytes + tensor_bytes > _BATCH_BYTE_CAP and batch:
+                                _flush_batch()
+                            batch.append((meta["name"], decoded_weight))
+                            batch_bytes += tensor_bytes
+                        _flush_batch()
+                    else:
+                        raise ValueError(
+                            f"Unsupported sparse delta transport: {transport}"
+                        )
+            return True, "Succeeded to apply sparse weight deltas online."
+        except Exception as e:
+            # Partial application is unrecoverable via deltas; require full resend.
+            error_msg = (
+                f"Failed to apply sparse weight deltas online: {e}. "
+                f"The model weights may be in an inconsistent state. "
+                f"Please discard the whole weights."
+            )
+            logger.error(error_msg)
+            return False, error_msg
+
def _update_bucketed_weights_from_distributed(
self, names, dtypes, shapes, group_name
):
@@ -1437,6 +1599,10 @@
return self._update_weights_from_flattened_bucket(
flattened_tensor_bucket_dict=named_tensors
)
+ if load_format == "flattened_bucket_delta":
+ return self._apply_weight_deltas_from_flattened_bucket(
+ flattened_tensor_bucket_dict=named_tensors
+ )

# We need to get device after patch otherwise the device would be wrong
self.device_module = torch.get_device_module(self.device)
@@ -1489,6 +1655,35 @@

return True, "Success"

+    def _apply_weight_deltas_from_flattened_bucket(
+        self,
+        flattened_tensor_bucket_dict,
+    ):
+        """Apply weight deltas carried in a flattened tensor bucket.
+
+        The bucket dict holds one flat tensor plus per-parameter metadata
+        (name/shape/dtype/start/end/numel); the metadata is re-materialized as
+        ``FlattenedTensorMetadata`` and ``FlattenedTensorBucket`` reconstructs
+        the individual delta tensors. Loading runs under
+        ``_additive_weight_copy_context`` so each delta is added onto the
+        existing parameter instead of overwriting it.
+
+        NOTE(review): unlike the *_from_distributed delta paths, this method has
+        no try/except — exceptions propagate to the caller. Confirm whether the
+        (False, error_msg) contract was intentionally dropped here.
+
+        Returns:
+            (bool, str): always (True, "Success") unless an exception propagates.
+        """
+        flattened_tensor = flattened_tensor_bucket_dict["flattened_tensor"]
+        metadata = flattened_tensor_bucket_dict["metadata"]
+
+        # Rebuild metadata objects field-by-field; the incoming `meta` entries
+        # may be a serialized/foreign variant of FlattenedTensorMetadata.
+        converted_metadata = []
+        for meta in metadata:
+            converted_meta = FlattenedTensorMetadata(
+                name=meta.name,
+                shape=meta.shape,
+                dtype=meta.dtype,
+                start_idx=meta.start_idx,
+                end_idx=meta.end_idx,
+                numel=meta.numel,
+            )
+            converted_metadata.append(converted_meta)
+
+        bucket = FlattenedTensorBucket(
+            flattened_tensor=flattened_tensor, metadata=converted_metadata
+        )
+        delta_tensors = bucket.reconstruct_tensors()
+
+        with _additive_weight_copy_context():
+            with _wrap_post_load_weights_with_original_copy_context(self.model):
+                self.model.load_weights(delta_tensors)
+        return True, "Success"
+
def get_weights_by_name(
self, name: str, truncate_size: int = 100
) -> Optional[torch.Tensor]:
@@ -2718,6 +2907,67 @@
return True, "Success"


+@contextlib.contextmanager
+def _restore_weight_copy_context():
+    """Temporarily restore the pristine ``Tensor.copy_``/``fill_`` methods.
+
+    Used inside ``_additive_weight_copy_context`` (via the post_load_weights
+    wrapper) when a nested operation must see normal, non-additive copy
+    semantics. The originals were captured at module import time
+    (_ORIGINAL_TENSOR_COPY/_ORIGINAL_TENSOR_FILL); on exit, whatever methods
+    were active when this context was entered are reinstated.
+    """
+    current_copy = torch.Tensor.copy_
+    current_fill = torch.Tensor.fill_
+    torch.Tensor.copy_ = _ORIGINAL_TENSOR_COPY
+    torch.Tensor.fill_ = _ORIGINAL_TENSOR_FILL
+    try:
+        yield
+    finally:
+        torch.Tensor.copy_ = current_copy
+        torch.Tensor.fill_ = current_fill
+
+
+@contextlib.contextmanager
+def _wrap_post_load_weights_with_original_copy_context(model):
+    """Make ``model.post_load_weights`` run with original copy_/fill_ semantics.
+
+    While ``_additive_weight_copy_context`` is active, any Tensor.copy_ turns
+    into add_; if the model defines a ``post_load_weights`` hook, its internal
+    copies must NOT be additive (presumably it re-processes weights after
+    loading — confirm against the model implementations). The hook is swapped
+    for a wrapper that restores the pristine methods for its duration, and put
+    back on exit. Models without the hook pass through untouched.
+    """
+    original_post_load = getattr(model, "post_load_weights", None)
+    if original_post_load is None:
+        yield
+        return
+
+    def wrapped_post_load_weights(*args, **kwargs):
+        with _restore_weight_copy_context():
+            return original_post_load(*args, **kwargs)
+
+    model.post_load_weights = wrapped_post_load_weights
+    try:
+        yield
+    finally:
+        model.post_load_weights = original_post_load
+
+
+@contextlib.contextmanager
+def _additive_weight_copy_context():
+    """Monkey-patch ``Tensor.copy_``/``fill_`` so weight loads become additive.
+
+    Inside this context, ``dst.copy_(src)`` performs ``dst.add_(src)`` (after
+    moving src to dst's device/dtype) and ``fill_(v)`` performs ``add_(v)``.
+    This lets the stock ``load_weights`` path apply deltas (param += delta)
+    without modification.
+
+    NOTE(review): the patch is process-wide on torch.Tensor, so ANY tensor
+    copy in any thread during this context becomes additive — assumes no
+    concurrent forward pass or other copy_ users while deltas are applied;
+    confirm the scheduler guarantees this.
+    """
+    original_copy_ = torch.Tensor.copy_
+    original_fill_ = torch.Tensor.fill_
+
+    # non_blocking is accepted for signature compatibility but ignored; add_
+    # is synchronous with respect to the stream semantics of the .to() call.
+    def _additive_copy_(self, src, non_blocking=False):
+        self.add_(src.to(device=self.device, dtype=self.dtype))
+        return self
+
+    def _additive_fill_(self, value):
+        self.add_(value)
+        return self
+
+    torch.Tensor.copy_ = _additive_copy_
+    torch.Tensor.fill_ = _additive_fill_
+    try:
+        yield
+    finally:
+        torch.Tensor.copy_ = original_copy_
+        torch.Tensor.fill_ = original_fill_
+
+
+def _unpack_bitmask(packed, numel, device):
+    """Expand a packed uint8 bitmask into a bool mask of length ``numel``.
+
+    Bit k of byte j maps to element 8*j + k (LSB-first within each byte, per
+    the ``>> arange(8)`` order below); trailing pad bits beyond numel are
+    dropped. Assumes the producer packs bits in the same LSB-first order —
+    confirm against the sender's encoder. ``packed`` must already live on
+    ``device`` for the broadcasted shift to work.
+    """
+    if numel == 0:
+        return torch.empty(0, dtype=torch.bool, device=device)
+    shifts = torch.arange(8, dtype=torch.uint8, device=device)
+    # (N,1) >> (8,) broadcasts to (N,8); flatten gives bytes in LSB-first order.
+    expanded = ((packed.unsqueeze(1) >> shifts) & 1).reshape(-1)
+    return expanded[:numel].to(dtype=torch.bool)
+
+
def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
params_dict = dict(model.named_parameters())
for name, tensor in named_tensors:
Loading
Loading