Inference | Per-block MoE routing storage for prefix caching #4301
lmcafee-nvidia wants to merge 3 commits into NVIDIA:main
Conversation
Convert MoE routing indices from GPU tensors to CPU numpy arrays after the forward pass, and add chunk-based accumulation infrastructure. This is the minimum subset of siddharth/support-nemo-rl-router-replay needed to support per-block routing storage.

Changes:
- `_router_record_bookkeeping` returns `Dict[int, np.ndarray]` (CPU numpy)
- `DynamicInferenceRequest.routing_indices` changed to `np.ndarray`
- `add_routing_indices`/`finalize_routing_chunks` for O(1) chunk staging
- ndarray serialization/deserialization support
- Engine uses chunk-based accumulation instead of `torch.cat`
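The chunk-based accumulation described above can be sketched roughly as follows. This is a minimal illustration, not the actual implementation; the internal field names (`_routing_chunks`) are assumptions, while `add_routing_indices`, `finalize_routing_chunks`, and `routing_indices` follow the PR description:

```python
import numpy as np

class DynamicInferenceRequest:
    """Minimal sketch of the chunk-staging API described in the PR."""

    def __init__(self):
        self._routing_chunks = []    # staged per-step chunks (assumed field name)
        self.routing_indices = None  # final [token_count, num_layers, topk] array

    def add_routing_indices(self, chunk: np.ndarray) -> None:
        # O(1) amortized: append to a list instead of concatenating every step
        self._routing_chunks.append(chunk)

    def finalize_routing_chunks(self) -> None:
        # Single concatenation when the request completes
        if self._routing_chunks:
            self.routing_indices = np.concatenate(self._routing_chunks, axis=0)
            self._routing_chunks = []

req = DynamicInferenceRequest()
req.add_routing_indices(np.zeros((4, 2, 2), dtype=np.int32))  # 4 prompt tokens
req.add_routing_indices(np.ones((1, 2, 2), dtype=np.int32))   # 1 decode token
req.finalize_routing_chunks()
print(req.routing_indices.shape)  # (5, 2, 2)
```

The point of staging chunks in a list is to avoid the O(n²) total copy cost of calling `torch.cat` on a growing tensor at every decode step.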
…ility

Move routing indices from per-request step-by-step accumulation to per-block storage on `KVBlockAllocator`. At request completion, routing is reconstructed by concatenating per-block routing in block order. Matched (prefix-cached) blocks retain routing from the original request, so reconstruction naturally covers all tokens, including skipped prefixes.

Key methods:
- `store_routing_per_block`: scatters routing into per-block storage
- `reconstruct_routing_from_blocks`: reassembles routing on completion
- `store_block_routing` / `get_block_routing`: low-level block storage

Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
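The scatter/reconstruct scheme above can be sketched as follows. The method names come from the commit message; everything else (class shape, block-id bookkeeping) is an assumption for illustration:

```python
import numpy as np

class KVBlockAllocatorSketch:
    """Hypothetical stand-in for the per-block routing store on KVBlockAllocator."""

    def __init__(self, block_size: int):
        self.block_size = block_size
        self._block_routing = {}  # block_id -> [<=block_size, num_layers, topk]

    def store_block_routing(self, block_id: int, routing: np.ndarray) -> None:
        self._block_routing[block_id] = routing

    def get_block_routing(self, block_id: int):
        return self._block_routing.get(block_id)

    def store_routing_per_block(self, block_ids, routing: np.ndarray) -> None:
        # Scatter a [token_count, num_layers, topk] array into per-block slices
        for i, bid in enumerate(block_ids):
            start = i * self.block_size
            self.store_block_routing(bid, routing[start:start + self.block_size])

    def reconstruct_routing_from_blocks(self, block_ids):
        # Matched (prefix-cached) blocks already hold routing from the original
        # request, so concatenating in block order covers skipped prefixes too.
        parts = [self.get_block_routing(bid) for bid in block_ids]
        if any(p is None for p in parts):
            return None  # a block is missing its routing
        return np.concatenate(parts, axis=0)

alloc = KVBlockAllocatorSketch(block_size=2)
routing = np.arange(4 * 3 * 2).reshape(4, 3, 2)  # 4 tokens, 3 layers, top-2
alloc.store_routing_per_block([7, 9], routing)
rebuilt = alloc.reconstruct_routing_from_blocks([7, 9])
print(np.array_equal(rebuilt, routing))  # True
```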
```python
]
if not routing_parts:
    return
flat_routing = np.concatenate(routing_parts, axis=0)  # [token_count, num_layers, topk]
```
It seems like this is the reverse of what we are doing in the text_generation_controller's `router_record_bookkeeping` function. There we get the map as [token_count, num_layers, topk] and convert it into a per-request map; here we are doing the opposite conversion. Do you think we could simply return the per-token map from `router_record_bookkeeping` and simplify things?
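The two representations the reviewer contrasts can be sketched like this. Both helper names and the `lengths` parameter are hypothetical; only the [token_count, num_layers, topk] layout comes from the discussion:

```python
import numpy as np

def flat_to_per_request(flat: np.ndarray, lengths: list) -> dict:
    """Split a flat [token_count, num_layers, topk] map into per-request arrays."""
    offsets = np.cumsum(lengths)[:-1]  # split points between requests
    return {i: part for i, part in enumerate(np.split(flat, offsets, axis=0))}

def per_request_to_flat(per_request: dict) -> np.ndarray:
    """Reverse conversion: concatenate per-request maps back into one flat map."""
    return np.concatenate([per_request[i] for i in sorted(per_request)], axis=0)

flat = np.arange(6 * 2 * 2).reshape(6, 2, 2)      # 6 tokens total
per_req = flat_to_per_request(flat, [4, 2])        # request 0: 4 tokens, request 1: 2
print(per_req[0].shape, per_req[1].shape)          # (4, 2, 2) (2, 2, 2)
print(np.array_equal(per_request_to_flat(per_req), flat))  # True
```

Since the two conversions are exact inverses, doing both (as the reviewer notes) is wasted work if the consumer wants the flat map anyway.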
We should not store the routing indices both in the block store and inside each request, given how much memory they can use. I suggest we use only the block store; once a request is finished, we reconstruct the entire data from its blocks and send it back to the coordinator.
```diff
-request.routing_indices = torch.cat(
-    [request.routing_indices, step_routing], dim=0
-)
+request.add_routing_indices(step_routing)
```
Can't we get rid of this storage now? Let the block store be the only place where we store the routing indices.
@lmcafee-nvidia this won't work with sequence parallelism
The CUDA-graph static buffer path in get_routing_indices() may return a tensor sliced to active_token_count (global unpadded), which can exceed the per-rank valid count under sequence parallelism. Truncate to padded_active_token_count // tp_size before the all-gather so only valid routing data is collected. Co-Authored-By: Claude Opus 4.6 (1M context) <noreply@anthropic.com>
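The truncation described in the commit message can be sketched as follows. This is an illustration of the slicing logic only (numpy stands in for torch, and the function name is hypothetical); the parameter names `padded_active_token_count` and `tp_size` come from the commit message:

```python
import numpy as np

def valid_routing_slice(routing, padded_active_token_count: int, tp_size: int):
    """Truncate a CUDA-graph static-buffer routing slice to this rank's
    valid token count before the all-gather (hypothetical helper)."""
    per_rank_valid = padded_active_token_count // tp_size
    # The static buffer may be sliced to the global unpadded active_token_count,
    # which can exceed per_rank_valid under sequence parallelism; keep only
    # the valid rows so the all-gather collects no stale routing data.
    return routing[:per_rank_valid]

# Buffer sliced to 10 global active tokens, but with padded_active_token_count=8
# and tp_size=2 only 4 rows are valid on this rank.
buf = np.arange(10).reshape(10, 1)
print(valid_routing_slice(buf, padded_active_token_count=8, tp_size=2).shape[0])  # 4
```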
Summary
- Per-block routing storage on `KVBlockAllocator`, so routing indices are scattered into KV cache blocks and reconstructed at request completion; prefix-cached blocks retain routing from the original request
- Chunk-based accumulation (`add_routing_indices`/`finalize_routing_chunks`) and numpy serialization support on `DynamicInferenceRequest`

Test plan

- `TestPerBlockRouting` tests: store/get round-trip, cleared on allocate/reset, persists through deregister, reconstruct from blocks, missing block returns None, survives LRU prefix match

🤖 Generated with Claude Code