
Feature/unswizzle #2732

Open
int-smart wants to merge 15 commits into NVIDIA:main from int-smart:feature/unswizzle

Conversation

@int-smart commented Mar 4, 2026

Description

This PR adds unswizzle support for scaling factors and extends the swizzle module so scaling tensors can be converted from GEMM-swizzled layout back to compact layout, including multi-tensor paths. It also adds round-trip and standalone tests to validate unswizzle correctness.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added unswizzle APIs and implementation in transformer_engine/common/swizzle/swizzle.cu and declarations in transformer_engine/common/include/transformer_engine/swizzle.h
  • Added multi-tensor unswizzle support with swizzle-like validation assumptions (homogeneous scaling mode/layout, swizzled input and compact output expectations)
  • Refactored multi-tensor unswizzle launch/kernels to mirror swizzle structure (split row-wise and column-wise kernels) for easier readability
  • Added/extended tests in tests/cpp/operator/test_swizzle.cu, including standalone unswizzle and swizzle→unswizzle round-trip coverage
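For context, a minimal CPU sketch of the transformation being reversed, assuming the cuBLAS-documented 128x4 scale-factor tile layout (helper names `swizzled_offset` and `unswizzle_ref` are illustrative only, not this PR's API):

```cpp
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <vector>

// CPU reference sketch, NOT the PR's kernel code. Assumes the cuBLAS-style
// 128x4 scale-factor tile layout, in which element (r, c) of a tile lands
// at byte offset (r % 32) * 16 + (r / 32) * 4 + c inside the 512-byte tile.
size_t swizzled_offset(size_t r, size_t c, size_t k) {
  const size_t tile_r = r / 128, tile_c = c / 4;  // which 128x4 tile
  const size_t in_r = r % 128, in_c = c % 4;      // position inside the tile
  const size_t tiles_per_row = k / 4;             // k assumed padded to 4
  const size_t tile_base = (tile_r * tiles_per_row + tile_c) * 512;
  return tile_base + (in_r % 32) * 16 + (in_r / 32) * 4 + in_c;
}

// Unswizzle = gather each compact element from its swizzled position.
std::vector<uint8_t> unswizzle_ref(const std::vector<uint8_t>& in, size_t m, size_t k) {
  std::vector<uint8_t> out(m * k);
  for (size_t r = 0; r < m; ++r)
    for (size_t c = 0; c < k; ++c)
      out[r * k + c] = in[swizzled_offset(r, c, k)];
  return out;
}
```

Scattering through `swizzled_offset` and then gathering through it again round-trips the compact buffer, which is exactly the property the round-trip tests check on the GPU kernels.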

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

int-smart and others added 6 commits March 3, 2026 20:40
- Introduced `nvte_unswizzle_scaling_factors` to convert swizzled scaling factors back to row-major format.
- Implemented `regs_unshuffle_with_bit_shifts` and `regs_unshuffle` for unshuffling operations in CUDA kernels.
- Added `unswizzle_row_scaling_kernel_impl` and `unswizzle_col_scaling_kernel_impl` for handling unswizzling in row and column scaling respectively.

These changes enhance the functionality of the swizzle module, enabling better handling of scaling factors in tensor operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
These enhancements test the changes introduced for unswizzling

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `compute_ref_unswizzle` to handle the conversion of swizzled scaling factors back to their original format.
- Added `performTestUnswizzle1D` to validate the unswizzling process with various scaling modes.
- Created `UnswizzleTestSuite` for comprehensive testing of unswizzling operations.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Moved the definition of `swizzle_row_scaling_kernel` to a new location for better organization.
- Ensured the kernel implementation is now properly defined and accessible for scaling operations in the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
- Introduced `multi_tensor_unswizzle_scaling_factors` to convert swizzled scaling factors back to their original row-major format.
- Implemented CUDA kernels for unswizzling in both row and column scaling, enhancing the swizzle module's functionality.
- Updated the launch function to handle multiple tensor unswizzling operations efficiently.

These changes improve the handling of scaling factors in tensor operations, ensuring better performance and organization within the swizzle module.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
greptile-apps bot (Contributor) commented Mar 4, 2026

Greptile Summary

This PR introduces unswizzle support for MXFP8/NVFP4 scaling factors, adding single-tensor and multi-tensor nvte_unswizzle_scaling_factors / nvte_multi_tensor_unswizzle_scaling_factors APIs that reverse the GEMM-swizzled layout back to the compact row-major format. New inverse kernels (unswizzle_row_scaling_kernel_impl, unswizzle_col_scaling_kernel_impl) mirror the existing swizzle kernels, and a suite of standalone and round-trip tests validates correctness for both aligned and padded shapes.

Key concerns:

  • Missing input-size validation in unswizzle_scaling_factors (swizzle.cu lines ~1159–1200): The function derives m and k from output->scale_inv.shape, then asserts m * k == output->scale_inv.numel() — a check that is always true for a well-formed 2D tensor and thus provides no safety. Critically, input->scale_inv.numel() is never validated against m * k. This is asymmetric with swizzle_scaling_factors (which validates output->numel()) and with multi_tensor_unswizzle_scaling_factors (which derives m/k from input's shape and validates the output). A caller passing a swizzled input that was built from differently-padded dimensions will trigger silent OOB reads. The same gap exists in both the rowwise and columnwise MXFP8 and NVFP4 paths.

  • unswizzle_col_scaling_kernel_impl SLM load relies on undocumented layout invariant (line ~461): The contiguous read of SF_TILE_SIZE_I32 * k_tiles_in_tb int32s from input_i32[i] is correct only because the swizzle kernel writes exactly that many contiguous bytes for each (M-tile, K-tile-block) pair. This invariant should be documented or guarded with an assertion; the existing tests do not stress non-power-of-two K-tile counts large enough to expose any edge case in this region.

  • Dual-scale asymmetry (single-tensor and multi-tensor unswizzle both reject tensors with both rowwise and columnwise scale factors), which breaks the round-trip for tensors that carry both — this is an ongoing discussion in prior threads and remains unresolved.

  • Roundtrip test only exercises aligned shapes: performTestSwizzleUnswizzleRoundtrip uses num_tiles vectors that produce M/K exact multiples of 128, so the padded-tail correctness validated by the standalone unswizzle tests is not covered by the round-trip suite.

Confidence Score: 2/5

  • Not safe to merge: missing input-size validation in unswizzle_scaling_factors can lead to silent OOB reads, and the dual-scale asymmetry (confirmed by a senior developer to need fixing) remains unresolved.
  • The single-tensor unswizzle_scaling_factors skips any check that input->scale_inv.numel() == m * k (where m/k come from the output tensor), making it possible to silently operate with a mis-sized input buffer. The multi-tensor path handles this correctly, highlighting the inconsistency. Additionally, the dual-scale restriction that prevents a tensor with both rowwise and columnwise scales from being unswizzled was flagged by a senior developer as needing to be removed, but the current PR still retains it. Together these represent correctness and API-completeness issues that should be addressed before merging.
  • transformer_engine/common/swizzle/swizzle.cu — specifically the unswizzle_scaling_factors function (input size validation) and unswizzle_col_scaling_kernel_impl (SLM load invariant documentation).

Important Files Changed

transformer_engine/common/swizzle/swizzle.cu: Adds unswizzle kernels and host-side dispatch (single- and multi-tensor). Key issues: unswizzle_scaling_factors derives m/k from the output tensor shape but never validates that the input tensor has the same element count; the output-only check is tautological. The unswizzle_col_scaling_kernel_impl SLM load relies on a layout invariant inherited from the swizzle kernel that is not explicitly documented. The dual-scale restriction (`!has_rowwise_scale_inv || !has_columnwise_scale_inv`, i.e. rejecting tensors that carry both scale factors) also remains.
tests/cpp/operator/test_swizzle.cu: Adds standalone unswizzle tests and round-trip tests. Previously reported UB (uninitialized SF_MODE_X/SF_MODE_Y) and missing spaces in skip messages appear addressed in the visible code. The roundtrip test compares only unpadded scale bytes (scale_shape[0]*scale_shape[1]), and all roundtrip shapes are exact multiples of 128, so padding-region correctness is not exercised in round-trips. The reference unswizzle implementation appears logically correct as the exact inverse of compute_ref_swizzle.
transformer_engine/common/include/transformer_engine/swizzle.h: Adds nvte_unswizzle_scaling_factors and nvte_multi_tensor_unswizzle_scaling_factors declarations with matching docstrings. No issues found.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[nvte_unswizzle_scaling_factors] --> B[unswizzle_scaling_factors]
    B --> C{scaling_mode?}
    C -->|MXFP8_1D| D[Derive m,k from output->scale_inv.shape]
    C -->|NVFP4_1D| E[Derive m,k from output scale shape]
    D --> F{rowwise or columnwise?}
    E --> F
    F -->|rowwise| G[Validate m%128==0, k%4==0]
    F -->|columnwise| G
    G --> H{vec_load_size selection\n based on num_tiles}
    H --> I[Launch unswizzle_scaling_kernel\nrow or col path]
    I -->|row_scaling=true| J[unswizzle_row_scaling_kernel_impl\nSLM load → regs_unshuffle → store compact]
    I -->|row_scaling=false| K[unswizzle_col_scaling_kernel_impl\nSLM load → regs_unshuffle_with_bit_shifts → store compact]

    L[nvte_multi_tensor_unswizzle_scaling_factors] --> M[multi_tensor_unswizzle_scaling_factors]
    M --> N{rowwise_unswizzle?}
    N -->|yes| O[Batch tensors into MultiSwizzleArgs\nm,k from input shape\nvalidate output.numel==m*k]
    N -->|no| P[Batch columnwise tensors]
    O --> Q[launch_multi_tensor_unswizzle_scaling_factors\nrowwise=true]
    P --> R[launch_multi_tensor_unswizzle_scaling_factors\nrowwise=false]
    Q --> S[multi_tensor_unswizzle_row_scaling_kernel]
    R --> T[multi_tensor_unswizzle_col_scaling_kernel]

    style D fill:#ffcccc,stroke:#cc0000
    style E fill:#ffcccc,stroke:#cc0000

Last reviewed commit: "Refactor unswizzling..."

@vthumbe1503 added the community-contribution label (PRs from external contributor outside the core maintainers, representing community-driven work) on Mar 4, 2026
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
@int-smart force-pushed the feature/unswizzle branch from 85ea04b to 17dbb33 on March 5, 2026 at 02:13
int-smart and others added 2 commits March 4, 2026 18:49
ptrendx (Member) commented Mar 11, 2026

@int-smart Please address the comments from Greptile and ideally also add the test case with the input not already padded to 128,128.

int-smart (Author) commented:
@ptrendx Will look into these

int-smart (Author) commented:
@ptrendx From what I understand, then, padding is not relevant to the unswizzle kernel. Since the padding is already applied during the swizzling operation, I can just mirror it back to the compact layout with the zero pads placed correctly, and that should do. Is that assumption correct? Initially I was thinking of removing the padding from the scale_inv itself, since this would be used for checkpointing.

int-smart and others added 2 commits March 12, 2026 19:53
- Updated unswizzling kernel implementations to remove original_M and original_K parameters, simplifying the function signatures.
- Enhanced test suite to utilize new unswizzling data shapes, ensuring comprehensive coverage of aligned and padded cases.

These changes improve the clarity and efficiency of the unswizzling process in the swizzle module.
Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
ptrendx (Member) commented Mar 16, 2026

@int-smart I'm not sure I follow, I think that what you are saying is probably correct, but let me try to clarify just in case:

  • the scaling factors, irrespective of the compact or GEMM-ready layout, are zero-padded to a multiple of [128, 4] (or the transpose in the compact columnwise case).
  • So for the unswizzle, you should just use the same size of the output unswizzled tensor as the original swizzled one. You don't even need to zero it before unswizzling, since the swizzled tensor already has 0s in the right places so unswizzling it will put 0s in the pad positions.
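The padding rule described above can be sketched as follows (hypothetical helpers, assuming 32-element rowwise scaling blocks as in MXFP8; these are not TE functions):

```cpp
#include <cassert>
#include <cstddef>

// Hypothetical helpers mirroring the [128, 4] alignment described above.
// For a data tensor [M, K] with one rowwise scale per 32 elements along K,
// the scale tensor is [M, ceil(K/32)], zero-padded so that its M dimension
// is a multiple of 128 and its K dimension a multiple of 4.
constexpr size_t round_up(size_t x, size_t mult) { return (x + mult - 1) / mult * mult; }
constexpr size_t padded_scale_m(size_t M) { return round_up(M, 128); }
constexpr size_t padded_scale_k(size_t K) { return round_up((K + 31) / 32, 4); }
```

This matches the shapes in the PR's test comments: M=160 pads scale-M to 256, and K=96 gives K/32 = 3 scale columns, padded to 4.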

int-smart (Author) commented:
@ptrendx Makes sense. I added that in the last commit.

Comment on lines +1150 to +1240
switch (scaling_mode) {
  case NVTE_MXFP8_1D_SCALING:
    NVTE_CHECK(is_fp8_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP8, got ",
               to_string(input->dtype()), ").");
    break;
  case NVTE_NVFP4_1D_SCALING:
    NVTE_CHECK(is_fp4_dtype(input->dtype()), "Input tensor has invalid dtype (expected FP4, got ",
               to_string(input->dtype()), ").");
    break;
  default:
    NVTE_ERROR("Invalid scaling mode");
}

const bool has_rowwise_scale_inv = input->scale_inv.has_data();
const bool has_columnwise_scale_inv = input->columnwise_scale_inv.has_data();
NVTE_CHECK(!has_rowwise_scale_inv || !has_columnwise_scale_inv,
           "Input tensor has both row-wise and column-wise scaling factors");
if (!has_rowwise_scale_inv && !has_columnwise_scale_inv) {
  return;
}

int m{0}, k{0};
switch (scaling_mode) {
  case NVTE_MXFP8_1D_SCALING: {
    if (has_rowwise_scale_inv) {
      NVTE_CHECK(input->scale_inv.shape.size() == 2,
                 "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
      m = input->scale_inv.shape[0];
      k = input->scale_inv.shape[1];
    } else if (has_columnwise_scale_inv) {
      NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
                 "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
                 ".");
      m = input->columnwise_scale_inv.shape[1];
      k = input->columnwise_scale_inv.shape[0];
    }
    break;
  }
  case NVTE_NVFP4_1D_SCALING: {
    if (has_rowwise_scale_inv) {
      NVTE_CHECK(input->scale_inv.shape.size() == 2,
                 "Expected 2D scaling factors, got shape=", input->scale_inv.shape, ".");
      m = input->scale_inv.shape[0];
      k = input->scale_inv.shape[1];
    } else if (has_columnwise_scale_inv) {
      NVTE_CHECK(input->columnwise_scale_inv.shape.size() == 2,
                 "Expected 2D scaling factors, got shape=", input->columnwise_scale_inv.shape,
                 ".");
      m = input->columnwise_scale_inv.shape[0];
      k = input->columnwise_scale_inv.shape[1];
    }
    break;
  }
  default:
    NVTE_ERROR("Invalid scaling mode");
}

constexpr int SF_TILE_DIM_M = 128;
constexpr int SF_TILE_DIM_K = 4;
NVTE_CHECK(m % SF_TILE_DIM_M == 0, "Input should be padded in M/N dimension!");
NVTE_CHECK(k % SF_TILE_DIM_K == 0, "Input should be padded in K dimension!");

if (has_rowwise_scale_inv) {
  NVTE_CHECK(output->scale_inv.has_data(),
             "Output tensor does not have row-wise scaling factors.");
}
if (has_columnwise_scale_inv) {
  NVTE_CHECK(output->columnwise_scale_inv.has_data(),
             "Output tensor does not have column-wise scaling factors.");
}

bool rowwise_unswizzle{false}, columnwise_unswizzle{false};
switch (scaling_mode) {
  case NVTE_MXFP8_1D_SCALING: {
    rowwise_unswizzle = has_rowwise_scale_inv;
    columnwise_unswizzle = has_columnwise_scale_inv;
    break;
  }
  case NVTE_NVFP4_1D_SCALING: {
    rowwise_unswizzle = true;
    columnwise_unswizzle = false;
    break;
  }
  default:
    NVTE_ERROR("Invalid scaling mode");
}

const dim3 block_size(TB_DIM, TB_DIM);
const int num_tiles_m = m / SF_TILE_DIM_M;
const int num_tiles_k = k / SF_TILE_DIM_K;


Review comment (Member):

The code is pretty convoluted here and it doesn't have to be. There are some pieces there that you could do at the beginning without looking at the scaling factor (like checking whether the input has scale_inv/columnwise_scale_inv and checking if the output has them too). For the rest I would say that avoiding code duplication here is not worth breaking of the flow of NVFP4/MXFP8 specific logic, so I would probably just have a larger switch with 2 completely separate code paths rather than multiple switch statements.

Comment on lines +1212 to +1219
if (has_rowwise_scale_inv) {
NVTE_CHECK(output->scale_inv.has_data(),
"Output tensor does not have row-wise scaling factors.");
}
if (has_columnwise_scale_inv) {
NVTE_CHECK(output->columnwise_scale_inv.has_data(),
"Output tensor does not have column-wise scaling factors.");
}
Review comment (Member):

I would say that the logic here is a little backwards, even though I understand how here it is not obvious. Ultimately it is the output that tells you what to do in the function - think about the quantize function where the input does not know anything about the format to which it is quantized and it is the output that controls scaling mode and whether we need rowwise or columnwise quantization. Therefore here I would also treat the output as a "source of truth" on what we need to do and then check that the input tensor provides the right data (as opposed to this code which looks to input to know what to do and then checks the output).
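The output-as-source-of-truth pattern described above might look like this sketch (illustrative types only; `ScaleView` and `validate_rowwise_unswizzle` are not the actual TE Tensor API):

```cpp
#include <cassert>
#include <cstddef>
#include <vector>

// Illustrative stand-in for a tensor's scale_inv view; not TE's SimpleTensor.
struct ScaleView {
  std::vector<size_t> shape;
  bool has_data() const { return !shape.empty(); }
  size_t numel() const {
    size_t n = 1;
    for (size_t d : shape) n *= d;
    return shape.empty() ? 0 : n;
  }
};

// Output-driven validation: the OUTPUT decides that a rowwise unswizzle is
// requested and supplies m/k; the INPUT is then checked to provide matching
// data, which is the symmetric check the review notes is missing.
bool validate_rowwise_unswizzle(const ScaleView& in, const ScaleView& out) {
  if (!out.has_data()) return true;  // nothing requested by the output
  if (out.shape.size() != 2) return false;
  return in.numel() == out.shape[0] * out.shape[1];
}
```

The same shape comparison, run in the other direction, is what the existing code does: it reads m/k from one side and then checks only that side's numel, which is why the check is tautological.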

int-smart (Author) replied:

Changed this for the single-tensor version. Let me know if that makes sense. Can you tell me how this would be called, so that I can check how the input and output are allocated? Currently I am assuming from your comment above that the output has all the necessary information to decide between rowwise, columnwise, scaling_mode, and data pointers, along with dimensions such as m and k. If this is fine, I can make these changes to the multi-tensor version as well.

ptrendx (Member) replied:

Yes, the changes look good. Please update the multitensor version accordingly.

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +1449 to +1461
kernel_args.block_range[0] = 0;
int vec_load_size = 4;
for (size_t i = 0; i < num_tensors; i++) {
  if (kernel_args.num_tensors == kMaxTensorsPerKernel) {
    if (vec_load_size == 3) vec_load_size = 1;
    launch_multi_tensor_unswizzle_scaling_factors<SF_TILE_DIM_M, SF_TILE_DIM_K>(
        kernel_args, vec_load_size, false, stream);
    kernel_args.num_tensors = 0;
    vec_load_size = 4;
  }
  const int m = input[i]->columnwise_scale_inv.shape[1];
  const int k = input[i]->columnwise_scale_inv.shape[0];

greptile-apps bot (Contributor) commented:

P2 original_m_list/original_k_list set but unused by unswizzle kernels

Inside multi_tensor_unswizzle_scaling_factors, both the rowwise path (lines ~955–956 and ~960–961) and the columnwise path (lines ~1006–1007) populate kernel_args.original_m_list[pos] and kernel_args.original_k_list[pos]. However, neither multi_tensor_unswizzle_row_scaling_kernel nor multi_tensor_unswizzle_col_scaling_kernel reads these fields — they only consume m_list and k_list. The swizzle kernels need the original (unpadded) dimensions to zero-fill padding, but the unswizzle kernels always operate on already-padded swizzled input and produce padded compact output, so no masking is required.

Setting unused struct fields is harmless today but adds noise and could mislead a reader into thinking the unswizzle kernels honour padding boundaries the same way the swizzle kernels do. Consider either removing these assignments or adding a comment explaining why they are intentionally populated (e.g., "kept for future per-element padding masking").

Comment on lines +249 to 263
std::vector<std::pair<size_t, size_t>> unswizzle_data_shapes = {
// Aligned: scale dims are already multiples of 128 and 4
{128, 128},
{128, 16896}, // K = 132 * 128, large K
{16896, 128}, // M = 132 * 128, large M
// M-padding only: M not a multiple of 128 (scale-M needs padding to 256)
{160, 128},
// scale-K padding only: K/32 = 3, padded to 4
{128, 96},
// Both M and scale-K need padding
{160, 96},
};

std::vector<std::pair<bool, bool>> scaling_mode = {
{true, false},
greptile-apps bot (Contributor) commented:

P2 Roundtrip test only covers aligned matrix dimensions

performTestSwizzleUnswizzleRoundtrip is instantiated exclusively with the existing num_tiles vector, which always produces M = num_tiles_M * MAT_TILE_DIM_M — values that are exact multiples of 128 (the scale-M alignment). The standalone performTestUnswizzle1D intentionally adds padded shapes (e.g., M=160, K=96) via unswizzle_data_shapes, but no equivalent padded cases exist for the roundtrip.

If the output-size validation or padding-mask logic ever diverges between the swizzle and unswizzle paths for non-aligned M/K, the roundtrip test would pass while standalone tests fail (or vice-versa). Consider adding a few padded shapes (e.g., {4, 3} tile-count pairs or raw {160, 96} shapes) to num_tiles or creating a separate data-shape vector for the roundtrip suite.

Review comment (Member):

@int-smart Is there a reason for that difference between the tests?

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
…streamline processing. Need to check if rowwise and columnwise both can be true. If yes the if else needs to account for that

Signed-off-by: Abhishek <abhi.dtu11@gmail.com>
Comment on lines +461 to +476
const void* input = kernel_args.input_list[tensor_id];
void* output = kernel_args.output_list[tensor_id];
const int M = kernel_args.m_list[tensor_id];
const int K = kernel_args.k_list[tensor_id];

constexpr int N_TILE_PER_TD = sizeof(LType) / sizeof(int);
constexpr int N_TILES_IN_TB = TB_DIM * N_TILE_PER_TD;

const int num_tiles_k = K / SF_TILE_DIM_K;
const int num_tiles_m = M / SF_TILE_DIM_M;
const int flat_offset = bid - kernel_args.block_range[tensor_id];
const int grid_dim_x = DIVUP(num_tiles_k, N_TILES_IN_TB);
const int grid_dim_y = num_tiles_m;
const int bid_x = flat_offset / grid_dim_y;
const int bid_y = flat_offset % grid_dim_y;

greptile-apps bot (Contributor) commented:

P2 unswizzle_col_scaling_kernel_impl: SLM load stride mismatch for multi-block K dimension

The SLM load for each M-tile reads SF_TILE_SIZE_I32 * k_tiles_in_tb contiguous int32s from input_i32[i]:

const int4* input_v4i = reinterpret_cast<const int4*>(input_i32[i]);
for (int j = linear_id; j < SF_TILE_SIZE_I32 * k_tiles_in_tb / 4; j += ...)
    slm_v4i[j] = input_v4i[j];

input_i32[i] is set to base + bid_x * TB_DIM * SF_TILE_SIZE_I32 + mt * SF_TILE_DIM_M_I32 * K_i32, where the stride between adjacent M-tiles in the swizzled layout is SF_TILE_DIM_M_I32 * K_i32.

For a full K-tile block (k_tiles_in_tb == TB_DIM == 32), the write size is 32 * SF_TILE_SIZE_I32 = 32 * 128 = 4096 int32s. The M-tile stride is 32 * K_i32. When K_i32 > 128 (e.g., K = 132 K-tiles), the M-tile stride 32 * 132 = 4224 > 4096, so there is a gap that the read does not cross—it is safe.

However, for the last K-tile block (when bid_x == grid_dim_x - 1) and k_tiles_in_tb < TB_DIM, the following K-tile block for the same M-tile starts at offset (bid_x+1) * TB_DIM * 128 + mt * 32 * K_i32, which is beyond the current read range. This appears correct in isolation, but the stride chosen by the swizzle for that region may leave uninitialised bytes between consecutive partial K-tile writes for the same M-tile.

Consider tracing through with K = 8 K-tiles (K_i32 = 8), M = 128:

  • num_tiles_k = 8 / SF_TILE_DIM_K_I32 = 8 / 4 = 2; grid_dim_x = DIVUP(2, TB_DIM) = 1
  • So bid_x=0 is the last block; k_tiles_in_tb = (2-1) % 32 + 1 = 2
  • Read: 2 * 128 = 256 int32s from 0 + 0 * 32 * 8 = 0
  • Swizzle stored: 2 * 128 = 256 int32s at offset 0

Layouts agree here. But consider K_i32 = 132, grid_dim_x = 2, bid_x = 1 (last), mt = 1:

  • input_i32[1] = 1 * 32 * 128 + 1 * 32 * 132 = 4096 + 4224 = 8320
  • k_tiles_in_tb = (33-1) % 32 + 1 = 1
  • Read: 1 * 128 = 128 int32s from 8320

Swizzle for (bid_x=1, mt=1) wrote 128 int32s starting at 1 * 32 * 128 + 1 * 32 * 132 = 8320. This matches.

After more careful analysis, the reads do appear correct for the cases tested by the test suite (all aligned shapes). However, the correctness relies on k_tiles_in_tb being computed identically in the swizzle and unswizzle kernels. Please add an assertion or comment clarifying the layout invariant assumed by this contiguous read, and add a test covering non-power-of-two K-tile counts (e.g., K = 132 * 4 = 528 with M = 256) to catch any latent mismatch.
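To make the traced arithmetic easy to re-check, the index formulas as quoted in this review boil down to the following (constants and names mirror the review's trace, not necessarily the kernel source):

```cpp
#include <cassert>

// Per the review's trace: a 128x4 FP8 scale tile is 512 bytes = 128 int32s,
// a threadblock covers TB_DIM = 32 K-tiles, and an M-tile spans 32 int32
// rows. These values are the review's assumptions about the kernel.
constexpr int TB_DIM = 32;
constexpr int SF_TILE_SIZE_I32 = 128;
constexpr int SF_TILE_DIM_M_I32 = 32;

// Start offset (in int32s) of the (bid_x, mt) block in the swizzled buffer.
constexpr int block_offset_i32(int bid_x, int mt, int K_i32) {
  return bid_x * TB_DIM * SF_TILE_SIZE_I32 + mt * SF_TILE_DIM_M_I32 * K_i32;
}

// Number of K-tiles handled by the last threadblock along K.
constexpr int k_tiles_in_last_tb(int num_tiles_k) {
  return (num_tiles_k - 1) % TB_DIM + 1;
}
```

Plugging in the review's K_i32 = 132, bid_x = 1, mt = 1 case reproduces the 8320 offset computed above, and the two k_tiles_in_tb examples (2 and 33 K-tiles) come out as 2 and 1 respectively.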

Comment on lines +1218 to +1242
        k = output->scale_inv.shape[1];
        NVTE_CHECK(static_cast<size_t>(m) * k == output->scale_inv.numel(),
                   "Expected output tensor to have ", static_cast<size_t>(m) * k,
                   " row-wise scaling factors, but got shape=", output->scale_inv.shape, ".");
        input_ptr = input->scale_inv.dptr;
        output_ptr = output->scale_inv.dptr;
      } else if (has_columnwise_scale_inv) {
        NVTE_CHECK(output->columnwise_scale_inv.shape.size() == 2,
                   "Expected 2D scaling factors, got shape=", output->columnwise_scale_inv.shape,
                   ".");
        m = output->columnwise_scale_inv.shape[0];
        k = output->columnwise_scale_inv.shape[1];
        NVTE_CHECK(static_cast<size_t>(m) * k == output->columnwise_scale_inv.numel(),
                   "Expected output tensor to have ", static_cast<size_t>(m) * k,
                   " column-wise scaling factors, but got shape=",
                   output->columnwise_scale_inv.shape, ".");
        input_ptr = input->columnwise_scale_inv.dptr;
        output_ptr = output->columnwise_scale_inv.dptr;
      }
      rowwise = true;
      break;
    }
    default:
      NVTE_ERROR("Invalid scaling mode");
  }
greptile-apps bot (Contributor) commented:

P2 NVTE_CHECK validation is derived from output, not input — input dimensions are never validated

This same pattern applies to the NVFP4 rowwise and columnwise paths: m/k are extracted from output->...shape and the NVTE_CHECK only tests that m * k == output->...numel() (trivially true), while input->...numel() is never checked against m/k. Adding the symmetrical input-side validation here and in the NVFP4 columnwise block (lines ~1254–1270) will close the gap.

// Example for NVFP4 rowwise path:
NVTE_CHECK(input->scale_inv.numel() == static_cast<size_t>(m) * k,
           "Input NVFP4 row-wise scaling factor size mismatch: got ",
           input->scale_inv.numel(), ", expected ", static_cast<size_t>(m) * k, ".");
