Conversation

@timmoon10 (Collaborator) commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block of Mixture-of-Experts models. It also adds an experimental fused operation for the grouped MLP block, built on a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op with support for interleaving the SwiGLU gate and linear units
  • Add a fused operation for the grouped MLP block (a usage sketch follows this list)
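
A sketch of how the new ops might compose into a grouped MLP with the fusible-ops API in transformer_engine.pytorch.ops. The constructor arguments below (group count, feature sizes, bias, glu_interleave_size) are illustrative assumptions based on attributes referenced in this PR, not the final API; the fused MXFP8 path would additionally require an MXFP8 recipe to be active.

# Hypothetical usage sketch; constructor signatures are assumptions, not the
# final API. Feature sizes are multiples of 256, matching the fusion check.
from transformer_engine.pytorch import ops as te_ops

num_experts, hidden, ffn = 8, 4096, 14336

grouped_mlp = te_ops.Sequential(
    te_ops.GroupedLinear(num_experts, hidden, 2 * ffn, bias=False),  # FC1: gate + linear units
    te_ops.ScaledSwiGLU(glu_interleave_size=32),                     # post-scaled, interleaved SwiGLU
    te_ops.GroupedLinear(num_experts, ffn, hidden, bias=False),      # FC2
)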

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Test is too permissive since the test should still be failing. The weights are not properly interleaved yet.

Signed-off-by: Tim Moon <tmoon@nvidia.com>
Review suggestion from @greptile-apps

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <4406448+timmoon10@users.noreply.github.com>
timmoon10 and others added 4 commits February 5, 2026 02:18

Signed-off-by: Tim Moon <tmoon@nvidia.com>
@greptile-apps greptile-apps bot (Contributor) left a comment

9 files reviewed, 4 comments

Comment on lines +290 to +292
quantizer=fc2_input_quantizers[group_idx],
requires_grad=False,
with_gemm_swizzled_scales=True,

Incorrect grad-required flags

In ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, swiglu_ctx.input_requires_grad and swiglu_ctx.extra_input_requires_grad are set to True unconditionally (and input_requires_grad is set to requires_grad unconditionally). This makes ScaledSwiGLU.fuser_backward compute grad_input and grad_extra_input even when neither input_ nor scales requires a gradient, which violates autograd semantics and can raise errors (e.g., scales.detach() is passed into the fused kernel, but extra_input_requires_grad=True still forces a gradient).

These flags should be derived from the actual requirements (see the sketch after this list):

  • input_requires_grad = input_.requires_grad
  • swiglu_ctx.extra_input_requires_grad = scales.requires_grad
  • for the FC weights, check each parameter's requires_grad (not just weight0)
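
A minimal sketch of that wiring inside ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.fuser_forward, assuming the names used in this PR (input_, scales, swiglu_ctx) and per-expert FC parameters named weight0, weight1, ...; the self.fc1 handle is an assumption for illustration:

# Propagate the real autograd requirements instead of hard-coding True.
input_requires_grad = input_.requires_grad
swiglu_ctx.input_requires_grad = input_.requires_grad
swiglu_ctx.extra_input_requires_grad = scales.requires_grad

# Weight grads: consult every expert's parameter, not just weight0
# (self.fc1 and the weight{i} parameter naming are assumptions).
fc1_weight_requires_grads = [
    getattr(self.fc1, f"weight{group_idx}").requires_grad
    for group_idx in range(self.fc1.num_groups)
]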

Comment on lines +420 to +460
# Return immediately if fused kernel is not supported
if not BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    return ops

# Check if recipe is supported
if recipe is None:
    return ops
if not recipe.mxfp8():
    return ops

# Scan through ops, fusing if possible
out = []
window, ops = ops[:3], ops[3:]
while len(window) == 3:

    # Check if window matches pattern
    matches_pattern = True
    if not (
        isinstance(window[0], GroupedLinear)
        and isinstance(window[1], ScaledSwiGLU)
        and isinstance(window[2], GroupedLinear)
    ):
        matches_pattern = False
    elif window[0].has_bias or window[2].has_bias:
        matches_pattern = False
    elif window[0].num_groups != window[2].num_groups:
        matches_pattern = False
    elif (
        window[0].in_features % 256 != 0
        or window[0].out_features % 256 != 0
        or window[2].in_features % 256 != 0
        or window[2].out_features % 256 != 0
    ):
        matches_pattern = False
    elif window[1].glu_interleave_size != 32:
        matches_pattern = False

    if matches_pattern:
        # Construct fused op if window matches pattern
        op = BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(
            fc1=window[0],

Broken fusion window scan

Both fuse_backward_ops and fuse_forward_ops have a window/shift loop that can drop or reorder ops when the pattern doesn’t match. In the non-matching branch you do out.extend(window[:-2]); window = window[-2:] and then immediately do out.extend(window[:-3]) (which is a no-op for a 2-element window) before refilling. This causes the scan to advance by 1 op in some cases and by 2 in others, and it never emits window[-1] until the very end. For sequences like [A,B,C,D] where [A,B,C] doesn’t match but [B,C,D] would (or vice versa), this loop will not correctly consider all 3-op windows and can produce an incorrect fused op list.

This needs a standard sliding-window approach (advance by 1 op when the window does not match; replace 3 ops with 1 when it does) so that no ops are skipped or duplicated; a sketch follows.
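
One possible shape for that loop, reusing names from the snippet above; matches_pattern(window) is a hypothetical helper wrapping the checks shown earlier, and the swiglu/fc2 keyword names are assumptions (only fc1 appears in the quoted code):

# Sketch of a standard sliding-window scan over ops.
out = []
idx = 0
while idx < len(ops):
    window = ops[idx : idx + 3]
    if len(window) == 3 and matches_pattern(window):
        # Replace the three matched ops with one fused op.
        out.append(
            BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8(
                fc1=window[0],
                swiglu=window[1],
                fc2=window[2],
            )
        )
        idx += 3
    else:
        # No match: emit this op unchanged and advance by one.
        out.append(ops[idx])
        idx += 1
return out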

Comment on lines +1 to +6
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.

"""Fusible operation for bias."""


Incorrect module docstring

transformer_engine/pytorch/ops/basic/grouped_linear.py starts with """Fusible operation for bias.""", which is a copy/paste from basic/bias.py and is incorrect for this file. This impacts generated docs and module-level help text and is user-visible.

Update it to describe GroupedLinear (e.g., “Fusible operation for grouped linear / grouped GEMM”).
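
One possible header for the module, following the wording suggested above:

# transformer_engine/pytorch/ops/basic/grouped_linear.py
"""Fusible operation for grouped linear (grouped GEMM)."""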

Comment on lines +350 to +357
ctx.dtype = dtype
ctx.save_for_backward(
    input_,
    scales if ctx.input_requires_grad else None,
)

return out, [()]


Always-saving scales tensor

In ScaledSwiGLU.fuser_forward, you set ctx.extra_input_requires_grad = extra_input.requires_grad, but you always save scales because the conditional is scales if ctx.input_requires_grad else None and ctx.input_requires_grad is forced to True. When scales.requires_grad=False, this needlessly keeps scales alive for backward and increases activation memory; worse, fuser_backward uses scales inside the ctx.input_requires_grad branch, so if someone later changes input_requires_grad to be accurate, the save condition would become wrong.

Save scales based on ctx.extra_input_requires_grad (or save both unconditionally but keep the grad-required flags consistent).
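
One way to keep the saved tensors aligned with accurate grad flags, assuming the names above; scales is kept whenever either backward branch might read it:

# Derive the flags from the tensors and gate the save on them.
ctx.input_requires_grad = input_.requires_grad
ctx.extra_input_requires_grad = extra_input.requires_grad
ctx.save_for_backward(
    input_,
    scales if (ctx.input_requires_grad or ctx.extra_input_requires_grad) else None,
)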

Labels: performance (Performance issues)