[PyTorch] Add grouped linear op and experimental fusion for grouped MLP #2622
Conversation
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet.
* Expose option for custom op fusions. Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.
* Add tests for custom ops
* [pre-commit.ci] auto fixes from pre-commit.com hooks
* Fix linter warnings and numerical test failures
* Tweak pattern matching logic with fixed window sizes
* Use TF32 tols in fused op tests
* Review suggestion from @greptile-apps
* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
/te-ci pytorch L1
Greptile Overview
Greptile Summary
Adds grouped linear operations and experimental MXFP8 fusion for Mixture-of-Experts grouped MLP blocks. Key Changes:
Minor Issue:
Confidence Score: 4.5/5
Important Files Changed
Sequence Diagram
sequenceDiagram
participant User
participant GroupedMLP as Grouped MLP Module
participant FC1 as GroupedLinear (FC1)
participant SwiGLU as ScaledSwiGLU
participant FC2 as GroupedLinear (FC2)
participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
participant Quantizer as FP8 Quantizers
User->>GroupedMLP: forward(input, split_sizes, probs)
alt Fusion Available (MXFP8 + SM100+)
GroupedMLP->>FusedOp: fuser_forward(input, split_sizes, probs)
FusedOp->>Quantizer: quantize inputs & weights (MXFP8)
Quantizer-->>FusedOp: quantized tensors
FusedOp->>FusedOp: grouped_gemm_swiglu_kernel()
Note over FusedOp: CuTe DSL kernel fuses:<br/>FC1 GEMM + SwiGLU + scaling
FusedOp->>FC2: grouped GEMM for FC2
FC2-->>FusedOp: output
FusedOp-->>GroupedMLP: final output
else Standard Path
GroupedMLP->>FC1: forward(input, split_sizes)
FC1->>FC1: split input by groups
FC1->>Quantizer: quantize inputs/weights if FP8
FC1->>FC1: general_grouped_gemm()
FC1-->>GroupedMLP: FC1 output
GroupedMLP->>SwiGLU: forward(FC1_out, probs)
SwiGLU->>SwiGLU: remove gate interleaving
SwiGLU->>SwiGLU: swiglu(gate, linear)
SwiGLU->>SwiGLU: multiply by probs (post-scale)
SwiGLU-->>GroupedMLP: scaled output
GroupedMLP->>FC2: forward(SwiGLU_out, split_sizes)
FC2->>FC2: split input by groups
FC2->>Quantizer: quantize inputs/weights if FP8
FC2->>FC2: general_grouped_gemm()
FC2-->>GroupedMLP: final output
end
GroupedMLP-->>User: output
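For reference, the standard (unfused) path above corresponds roughly to the following per-expert computation. This is a minimal PyTorch sketch under stated assumptions (tokens already permuted by expert, per-token routing probs, non-interleaved gate/linear halves); it omits the FP8 quantization and gate-interleaving handling that the real ops perform:

```python
import torch
import torch.nn.functional as F

def grouped_mlp_reference(x, split_sizes, probs, fc1_weights, fc2_weights):
    """Unfused reference: per-group FC1 -> SwiGLU -> prob post-scale -> FC2.

    x: [num_tokens, hidden], already permuted by expert
    split_sizes: tokens per expert (sums to num_tokens)
    probs: [num_tokens] routing probabilities
    fc1_weights[i]: [2 * ffn_hidden, hidden], fc2_weights[i]: [hidden, ffn_hidden]
    """
    outputs = []
    x_chunks = torch.split(x, split_sizes, dim=0)
    p_chunks = torch.split(probs, split_sizes, dim=0)
    for x_i, p_i, w1, w2 in zip(x_chunks, p_chunks, fc1_weights, fc2_weights):
        h = x_i @ w1.t()                    # FC1: produces gate and linear halves
        gate, linear = h.chunk(2, dim=-1)   # (the real op uses 32-wide gate interleaving)
        h = F.silu(gate) * linear           # SwiGLU
        h = h * p_i.unsqueeze(-1)           # post-scale by routing probabilities
        outputs.append(h @ w2.t())          # FC2 back to hidden size
    return torch.cat(outputs, dim=0)
```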
quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?
I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
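(Context for why the permutes are free: `permute`/`transpose` in PyTorch only change the strides of a view, so producing the non-contiguous layout the kernel expects involves no data movement.)

```python
import torch

x = torch.randn(128, 256)
x_t = x.t()                              # transpose/permute changes strides, not data
assert not x_t.is_contiguous()           # non-contiguous dims, as the kernel API expects
assert x_t.data_ptr() == x.data_ptr()    # same storage, so the permute is effectively a no-op
```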
)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
After the SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?
Yes, the probs are passed into the kernel here: https://github.com/timmoon10/TransformerEngine/blob/46294be478f6551e2cf251283adc7529ddb2964e/transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py#L264
Review suggestions from @greptile-apps
3 files reviewed, no comments
3 files reviewed, no comments
3 files reviewed, 1 comment
if swiglu.glu_interleave_size != 32:
    raise ValueError(
        "Fused kernel requires 32-wide GLU interleaving, "
        "but got glu_interleave_size={swiglu.glu_interleave_size}."
missing f prefix for f-string interpolation
| "but got glu_interleave_size={swiglu.glu_interleave_size}." | |
| f"but got glu_interleave_size={swiglu.glu_interleave_size}." |
Description
This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
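As a rough usage sketch, with op names taken from the review summary above but constructor arguments and forward signatures assumed rather than final:

```python
import torch
import transformer_engine.pytorch as te

# Hypothetical shapes and constructor arguments, for illustration only.
num_experts, hidden, ffn_hidden = 4, 1024, 4096
fc1 = te.ops.GroupedLinear(num_experts, hidden, 2 * ffn_hidden)  # gate + linear halves
act = te.ops.ScaledSwiGLU()                                      # SwiGLU with routing-prob post-scale
fc2 = te.ops.GroupedLinear(num_experts, ffn_hidden, hidden)

x = torch.randn(32, hidden, device="cuda", dtype=torch.bfloat16)
split_sizes = [8, 8, 8, 8]              # tokens per expert, input already permuted by expert
probs = torch.rand(32, device="cuda")   # routing probabilities per token

# Mirrors the "standard path" in the sequence diagram; with MXFP8 recipes on SM100+,
# the experimental fusion replaces FC1 + SwiGLU + post-scale with a single CuTe DSL kernel.
y = fc2(act(fc1(x, split_sizes), probs), split_sizes)
```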
Type of change
Changes
Checklist: