
Conversation


@timmoon10 timmoon10 commented Jan 24, 2026

Description

This PR adds a grouped linear op, which can be used in the grouped MLP block in Mixture-of-Experts models. It also adds an experimental fused operation for a grouped MLP block, using a CuTe DSL kernel that computes an MXFP8 grouped GEMM and SwiGLU.
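
A minimal reference sketch, in plain PyTorch, of the semantics a grouped linear op provides: the input is split along its first dimension into per-group chunks and each chunk is transformed by its own weight matrix. The function name, signature, and shapes below are illustrative only, not the API added in this PR.

import torch

def grouped_linear_reference(x, weights, split_sizes):
    # x:           (total_tokens, in_features), rows grouped by expert
    # weights:     one (out_features, in_features) weight per group/expert
    # split_sizes: number of rows of x assigned to each group
    chunks = torch.split(x, split_sizes, dim=0)
    outputs = [chunk @ w.t() for chunk, w in zip(chunks, weights)]
    return torch.cat(outputs, dim=0)

# Example: 3 experts with 4 + 2 + 6 tokens, hidden size 8 -> 16
x = torch.randn(12, 8)
weights = [torch.randn(16, 8) for _ in range(3)]
y = grouped_linear_reference(x, weights, [4, 2, 6])
assert y.shape == (12, 16)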

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Add a grouped linear operation
  • Add a post-scaled SwiGLU op and add support for interleaving SwiGLU gate and linear units (a reference sketch follows this list)
  • Add a fused operation for grouped MLP
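
To make the post-scaled SwiGLU item concrete, here is a plain-PyTorch sketch of a post-scaled SwiGLU applied to interleaved gate/linear units. The block layout (gate and linear units alternating in blocks of interleave_size along the last dimension) and the per-token broadcast of probs are assumptions for illustration; the actual op may lay the units out differently.

import torch
import torch.nn.functional as F

def post_scaled_swiglu_reference(fc1_out, probs, interleave_size=32):
    # Assumed layout: gate and linear units alternate along the last dim
    # in blocks of `interleave_size`: [g0, l0, g1, l1, ...]
    hidden = fc1_out.shape[-1]
    blocks = fc1_out.reshape(-1, hidden // (2 * interleave_size), 2, interleave_size)
    gate, linear = blocks.unbind(dim=-2)       # remove gate/linear interleaving
    gate = gate.reshape(fc1_out.shape[0], hidden // 2)
    linear = linear.reshape(fc1_out.shape[0], hidden // 2)
    out = F.silu(gate) * linear                # SwiGLU
    return out * probs.unsqueeze(-1)           # post-scale by routing probabilities

# Example: 6 tokens, 128 interleaved units -> 64 outputs per token
x = torch.randn(6, 128)
probs = torch.rand(6)
assert post_scaled_swiglu_reference(x, probs).shape == (6, 64)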

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

timmoon10 and others added 30 commits January 7, 2026 00:15
Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <[email protected]>
Signed-off-by: Tim Moon <[email protected]>
The test is too permissive, since it should still be failing: the weights are not properly interleaved yet.

Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 added the performance label Jan 24, 2026
timmoon10 added a commit to timmoon10/TransformerEngine that referenced this pull request Jan 24, 2026
timmoon10 added a commit that referenced this pull request Jan 25, 2026
* Expose option for custom op fusions

Refactor fusion functions to remove index bookkeeping. Refactor fused ops to use consistent operation order.

Signed-off-by: Tim Moon <[email protected]>

* Add tests for custom ops

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Fix linter warnings and numerical test failures

Signed-off-by: Tim Moon <[email protected]>

* Tweak pattern matching logic with fixed window sizes

Signed-off-by: Tim Moon <[email protected]>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Use TF32 tols in fused op tests

Signed-off-by: Tim Moon <[email protected]>

* Review suggestion from @greptile-apps

Signed-off-by: Tim Moon <[email protected]>

* Backpropagate fixes from #2622

Signed-off-by: Tim Moon <[email protected]>

---------

Signed-off-by: Tim Moon <[email protected]>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
@timmoon10 timmoon10 mentioned this pull request Jan 25, 2026
@timmoon10 timmoon10 changed the title [PyTorch] Prototype of fused operation for grouped MLP [PyTorch] Add grouped linear op and experimental fusion for grouped MLP Jan 25, 2026
Signed-off-by: Tim Moon <[email protected]>
@timmoon10 timmoon10 marked this pull request as ready for review January 25, 2026 01:00
@timmoon10 timmoon10 (Collaborator, Author) commented:

/te-ci pytorch L1


greptile-apps bot commented Jan 25, 2026

Greptile Overview

Greptile Summary

Adds grouped linear operations and experimental MXFP8 fusion for Mixture-of-Experts grouped MLP blocks.

Key Changes:

  • Introduced a GroupedLinear operation that applies multiple linear transformations by splitting the input along its first dimension, enabling efficient expert parallelism in MoE models
  • Refactored SwiGLU operations from activation.py into a dedicated swiglu.py module, adding ScaledSwiGLU with post-scaling and optional gate/linear unit interleaving
  • Implemented an experimental ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 fusion using a CuTe DSL kernel from cuDNN (requires SM100+) that fuses grouped GEMM + SwiGLU + post-scale into a single kernel
  • Full FP8/MXFP8 quantization support with rowwise/columnwise quantizers throughout the operation chain
  • Comprehensive test coverage including quantization variants, gradient checking, and fusion verification

Minor Issue:

  • Missing f prefix on f-string at line 90 of forward_grouped_mlp.py

Confidence Score: 4.5/5

  • Safe to merge after fixing the f-string syntax issue on line 90
  • Well-architected implementation with comprehensive test coverage. All previously identified issues have been resolved except one minor f-string syntax error. The grouped linear and fusion logic is sound, with proper quantization handling and backward pass implementation.
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py requires fix on line 90

Important Files Changed

  • transformer_engine/pytorch/ops/basic/grouped_linear.py: new file implementing grouped linear operations for MoE models with proper quantization support
  • transformer_engine/pytorch/ops/basic/swiglu.py: SwiGLU operations refactored out of activation.py, plus ScaledSwiGLU with interleaving support
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py: experimental CuTe DSL kernel fusion for MXFP8 grouped MLP; one f-string syntax issue on line 90

Sequence Diagram

sequenceDiagram
    participant User
    participant GroupedMLP as Grouped MLP Module
    participant FC1 as GroupedLinear (FC1)
    participant SwiGLU as ScaledSwiGLU
    participant FC2 as GroupedLinear (FC2)
    participant FusedOp as ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8
    participant Quantizer as FP8 Quantizers

    User->>GroupedMLP: forward(input, split_sizes, probs)
    
    alt Fusion Available (MXFP8 + SM100+)
        GroupedMLP->>FusedOp: fuser_forward(input, split_sizes, probs)
        FusedOp->>Quantizer: quantize inputs & weights (MXFP8)
        Quantizer-->>FusedOp: quantized tensors
        FusedOp->>FusedOp: grouped_gemm_swiglu_kernel()
        Note over FusedOp: CuTe DSL kernel fuses:<br/>FC1 GEMM + SwiGLU + scaling
        FusedOp->>FC2: grouped GEMM for FC2
        FC2-->>FusedOp: output
        FusedOp-->>GroupedMLP: final output
    else Standard Path
        GroupedMLP->>FC1: forward(input, split_sizes)
        FC1->>FC1: split input by groups
        FC1->>Quantizer: quantize inputs/weights if FP8
        FC1->>FC1: general_grouped_gemm()
        FC1-->>GroupedMLP: FC1 output
        
        GroupedMLP->>SwiGLU: forward(FC1_out, probs)
        SwiGLU->>SwiGLU: remove gate interleaving
        SwiGLU->>SwiGLU: swiglu(gate, linear)
        SwiGLU->>SwiGLU: multiply by probs (post-scale)
        SwiGLU-->>GroupedMLP: scaled output
        
        GroupedMLP->>FC2: forward(SwiGLU_out, split_sizes)
        FC2->>FC2: split input by groups
        FC2->>Quantizer: quantize inputs/weights if FP8
        FC2->>FC2: general_grouped_gemm()
        FC2-->>GroupedMLP: final output
    end
    
    GroupedMLP-->>User: output

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <[email protected]>

quantizer.optimize_for_gemm = True
fc1_xs = tex.split_quantize(fc1_x, split_sizes_cpu, fc1_input_quantizers)

# Pack data tensors
A reviewer (Member) asked:

Maybe a silly question: is this packing and unpacking code just for verification, or will it be in the final version?

timmoon10 (Collaborator, Author) replied:

I'm working on getting rid of the concatenations, but the permutes are no-ops. The kernel API expects tensors with non-contiguous dims: https://github.com/NVIDIA/cudnn-frontend/blob/main/python/cudnn/grouped_gemm/grouped_gemm_swiglu/api.py#L240-L245
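
A quick illustration of why the permutes are free in PyTorch: permute only rewrites strides on a view of the same storage, so it yields the non-contiguous layout the kernel API expects without copying any data.

import torch

x = torch.randn(128, 256)
xt = x.permute(1, 0)                    # metadata-only: strides change, no copy
assert xt.data_ptr() == x.data_ptr()    # same underlying storage
assert not xt.is_contiguous()           # non-contiguous dims, as the kernel API expects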

)

# Fused kernel for FC1 + SwiGLU + post-scale
fc1_kernel_out = self.grouped_gemm_swiglu_kernel()(
A reviewer (Contributor) asked:

After the SwiGLU, the output usually needs to be multiplied by permuted_probs. Is this weighted SwiGLU supported?

Signed-off-by: Tim Moon <[email protected]>
Review suggestions from @greptile-apps

Co-authored-by: greptile-apps[bot] <165735046+greptile-apps[bot]@users.noreply.github.com>
Signed-off-by: Tim Moon <[email protected]>
@greptile-apps greptile-apps bot left a comment

3 files reviewed, no comments

@greptile-apps greptile-apps bot left a comment

3 files reviewed, 1 comment

if swiglu.glu_interleave_size != 32:
    raise ValueError(
        "Fused kernel requires 32-wide GLU interleaving, "
        "but got glu_interleave_size={swiglu.glu_interleave_size}."
greptile-apps bot (Contributor) commented:
missing f prefix for f-string interpolation

Suggested change
"but got glu_interleave_size={swiglu.glu_interleave_size}."
f"but got glu_interleave_size={swiglu.glu_interleave_size}."


Labels

performance (Performance issues)

3 participants