[PyTorch] Add ops for MoE grouped MLP #2664
base: main
Conversation
Signed-off-by: Tim Moon <tmoon@nvidia.com>
Greptile Overview
Greptile Summary
This PR adds new PyTorch fusible ops required for MoE grouped MLP blocks: a grouped linear op and a ScaledSwiGLU op. The changes integrate by exporting the new ops in the te.ops namespace.
Confidence Score: 3/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant T as Test
participant GL as GroupedLinear
participant SS as ScaledSwiGLU
participant TE as tex
participant GG as general_grouped_gemm
T->>GL: forward(x, split_sizes)
GL->>TE: split_quantize(x) (fp8 only)
GL->>GG: grouped_gemm fprop
GG-->>GL: out
T->>SS: forward(out, scales)
SS->>TE: swiglu
T->>SS: backward(dy)
SS->>TE: dswiglu
T->>GL: backward(dy)
GL->>GG: grouped_gemm dgrad/wgrad
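For orientation, the flow above corresponds roughly to the following pure-PyTorch reference of the grouped MLP math (an illustrative sketch only, not the PR's API; all names and signatures here are assumptions). The fusible ops replace the Python loop with grouped GEMMs and fused activation kernels.

```python
import torch
import torch.nn.functional as F

def grouped_mlp_ref(
    x: torch.Tensor,                   # (num_tokens, hidden)
    split_sizes: list[int],            # tokens routed to each expert
    fc1_weights: list[torch.Tensor],   # per-expert (2 * ffn, hidden)
    fc2_weights: list[torch.Tensor],   # per-expert (hidden, ffn)
    scales: torch.Tensor,              # (num_tokens,) per-token scaling factors
) -> torch.Tensor:
    outs = []
    for chunk, w1, w2, s in zip(
        x.split(split_sizes), fc1_weights, fc2_weights, scales.split(split_sizes)
    ):
        h = chunk @ w1.t()               # grouped linear (fc1 fprop GEMM)
        gate, linear = h.chunk(2, dim=-1)
        h = F.silu(gate) * linear        # SwiGLU
        h = h * s.unsqueeze(-1)          # per-token scaling (ScaledSwiGLU)
        outs.append(h @ w2.t())          # grouped linear (fc2)
    return torch.cat(outs)
```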
/te-ci pytorch
return ref, test

def assert_close(
Shouldn't this be in some more generic place?
assert_close(y_test, y_ref, **tols)
assert_close(x_test.grad, x_ref.grad, **tols)
def test_interleaved_swiglu(self):
Why a separate test rather than a parameter in the other test? Does this only work for specific scenarios?
None of the other activations support interleaving since it's an implementation detail of the fused GEMM + SwiGLU kernel.
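For context, a minimal pure-PyTorch sketch of the two layouts (illustrative only; the actual half/channel ordering used by the TE kernels may differ):

```python
import torch
import torch.nn.functional as F

def swiglu_ref(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
    """Reference SwiGLU: silu(gate) * linear.

    Standard layout: the gate and linear projections occupy the two halves
    of the last dimension. Interleaved layout: gate and linear channels
    alternate along the last dimension, which is how the fused
    GEMM + SwiGLU kernel lays out its output.
    """
    if interleaved:
        gate, linear = x[..., 0::2], x[..., 1::2]
    else:
        gate, linear = x.chunk(2, dim=-1)
    return F.silu(gate) * linear
```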
# Out-of-place concatenation when view tensors have different storage
# Note: This works around an edge case with the split_quantize
# function, which might allocate a buffer and construct
# subviews. However, in order to reduce CPU overheads, these
# views are configured manually outside of PyTorch. PyTorch
# doesn't know these views share the same memory, and it
# blocks us from reconstructing the full tensor because it
# thinks we are accessing out-of-bounds memory.
if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
    return torch.cat(tensors, dim=dim)
Should this be done before this big for loop to short circuit it?
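For illustration, a rough sketch of what hoisting that check above the loop could look like (hypothetical helper and variable names; not the actual noop_cat code):

```python
import torch

def noop_cat_sketch(tensors: list[torch.Tensor], dim: int = 0) -> torch.Tensor:
    out_shape = list(tensors[0].size())
    out_shape[dim] = sum(t.size(dim) for t in tensors)
    data_ptr_stride = tensors[0].stride(dim) * tensors[0].element_size()

    # Short-circuit before inspecting each tensor: if the first tensor's
    # storage is too small to hold the concatenated result, the inputs
    # cannot all be views into one shared buffer, so fall back to an
    # out-of-place concatenation immediately.
    if tensors[0].untyped_storage().nbytes() < out_shape[dim] * data_ptr_stride:
        return torch.cat(tensors, dim=dim)

    # ... otherwise walk the tensors, verify they are back-to-back views
    # into the same storage, and return a no-op view of that storage ...
    raise NotImplementedError("sketch only")
```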
| "SiLU", | ||
| "SwiGLU", | ||
| "ClampedSwiGLU", | ||
| ] |
This looks like a breaking change in the API? Or we assume that people would always use those ops from the top level?
Accessing via te.ops or te.ops.basic is definitely the best practice. I don't see any downstream code accessing these via te.ops.basic.activation: https://github.com/search?q=ops.basic.activation&type=code
Megatron-LM accesses SwiGLU via te.ops:
https://github.com/NVIDIA/Megatron-LM/blob/24515084a6395bc4c30e20368d47d073868e95e6/megatron/core/extensions/transformer_engine.py#L2076
I don't see ClampedSwiGLU being accessed anywhere.
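For reference, the recommended top-level access pattern looks like this (a minimal sketch; constructor arguments omitted):

```python
import transformer_engine.pytorch as te

# Preferred: access activation ops from the public ops namespace rather
# than the internal te.ops.basic.activation module.
swiglu = te.ops.SwiGLU()
```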
#
# See LICENSE for license information.

"""Fusible operation for bias."""
Not really :-).
#
# See LICENSE for license information.

"""Fusible operation for multiplying with extra input tensor."""
Not really.
class SwiGLU(BasicOperation):
    r"""Swish gated linear unit
The doc should include the information about the interleaving parameter and what it means.
alpha : float
    The scaling factor for the sigmoid function used in the activation.
cache_quantized_input : bool, default = ``False``
    Quantize input tensor when caching for use in the backward pass.
We will need similar interleaving support here pretty soon.
If the SwiGLU output has shape ``(d_1, ..., d_n)``, it is
multiplied with an extra input tensor of shape
``(d_1, ..., d_{n-1})``.
Also please add the docs for the interleaving. BTW, is the "cache quantized input" parameter only for LoRA, or is there a reason it exists only in the regular SwiGLU and nowhere else?
It was an experimental feature for LoRA that we never really ended up using.
swiglu_out = tex.swiglu(swiglu_in, None)
out = swiglu_out * scales.unsqueeze(-1)
Considering it is implemented with 2 kernels anyway, what is the benefit of having this operation here? I would prefer to have the ScaleWithExtraInput basic op or something like that instead.
This was the approach I used in my initial implementation (#2605), but it's not compatible with the fused GEMM + SwiGLU kernel (https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py). If we have a standalone scale op, then we need to cache its input for the backward pass. However, the fused kernel assumes you are doing activation recompute, and it only outputs the SwiGLU input and the scale output. Rather than intertwining the implementations of the SwiGLU and scale ops to support activation recompute, I just implemented a new op that does it explicitly.
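To make the recompute constraint concrete, a rough autograd sketch of what ScaledSwiGLU needs to cache (illustrative only, not the PR's actual implementation; a pure-PyTorch SwiGLU stands in for the TE kernels):

```python
import torch
import torch.nn.functional as F

class ScaledSwiGLURef(torch.autograd.Function):
    """Caches only the SwiGLU input and the scales, matching the
    activation-recompute contract of the fused GEMM + SwiGLU kernel."""

    @staticmethod
    def forward(ctx, swiglu_in, scales):
        gate, linear = swiglu_in.chunk(2, dim=-1)
        ctx.save_for_backward(swiglu_in, scales)       # output is not cached
        return F.silu(gate) * linear * scales.unsqueeze(-1)

    @staticmethod
    def backward(ctx, grad_out):
        swiglu_in, scales = ctx.saved_tensors
        gate, linear = swiglu_in.chunk(2, dim=-1)
        sig = torch.sigmoid(gate)
        silu_gate = gate * sig
        swiglu_out = silu_gate * linear                # recomputed activation
        grad_scales = (grad_out * swiglu_out).sum(dim=-1)
        grad_swiglu_out = grad_out * scales.unsqueeze(-1)
        grad_gate = grad_swiglu_out * linear * sig * (1 + gate * (1 - sig))
        grad_linear = grad_swiglu_out * silu_gate
        return torch.cat([grad_gate, grad_linear], dim=-1), grad_scales
```

With a standalone scale op, the scale would instead need its own input (the SwiGLU output) cached for backward, which is exactly what the fused kernel doesn't provide.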
Description
This PR adds ops needed for the grouped MLP block in Mixture-of-Experts models. In particular, it adds a grouped linear op (similar to the GroupedLinear module) and a ScaledSwiGLU op. It is the same as #2622, but doesn't include the fused ops with experimental kernels. Closes #2560.
Type of change
Changes
noop_cat function
Checklist: