71 commits
5175aad
Naive implementation of grouped linear op
timmoon10 Jan 7, 2026
5ffd57e
Use grouped GEMM tex functions
timmoon10 Jan 7, 2026
2ee42da
Support quantized compute
timmoon10 Jan 8, 2026
93e71df
Debug test failures with MXFP8 or NVFP4 params
timmoon10 Jan 8, 2026
fdddc47
Add multiply op
timmoon10 Jan 10, 2026
b448a17
Bug fixes
timmoon10 Jan 10, 2026
3f38897
Expose option for custom op fusions
timmoon10 Jan 14, 2026
a359b67
Add tests for custom ops
timmoon10 Jan 14, 2026
5f7204f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 14, 2026
8ddb8ce
Fix linter warnings and numerical test failures
timmoon10 Jan 14, 2026
cfc2617
Tweak pattern matching logic with fixed window sizes
timmoon10 Jan 15, 2026
0ce5dfb
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
9bf5843
Merge branch 'main' into tmoon/custom-fused-ops
timmoon10 Jan 15, 2026
4992903
Use TF32 tols in fused op tests
timmoon10 Jan 15, 2026
9ab7751
Review suggestion from @greptile-apps
timmoon10 Jan 15, 2026
a086d81
Merge branch 'main' into tmoon/custom-fused-ops
timmoon10 Jan 15, 2026
f05f7a8
Merge branch 'main' into tmoon/grouped-linear-op
timmoon10 Jan 15, 2026
9348138
Fix linter warnings
timmoon10 Jan 15, 2026
5366729
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 15, 2026
1b0b229
Merge branch 'tmoon/grouped-linear-op' into tmoon/cute-gemm-swiglu
timmoon10 Jan 15, 2026
3bbe881
Merge branch 'tmoon/custom-fused-ops' into tmoon/cute-gemm-swiglu
timmoon10 Jan 15, 2026
321646e
Initial impl of fused op for grouped MLP
timmoon10 Jan 16, 2026
e137451
Import group GEMM+SwiGLU kernel
timmoon10 Jan 17, 2026
11da59d
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Jan 20, 2026
cb728bb
Add unit test for grouped MLP op
timmoon10 Jan 20, 2026
e7459cc
Call fused group GEMM + SwiGLU kernel
timmoon10 Jan 21, 2026
b15ca0d
Debug test failures
timmoon10 Jan 21, 2026
3da2c17
Get test to not pass trivially
timmoon10 Jan 22, 2026
0270eb1
Handle interleaving for SwiGLU
timmoon10 Jan 22, 2026
0b09790
Fix numeric tests, except for probs grad
timmoon10 Jan 22, 2026
7c40290
Use pre-swizzled scales from GEMM+SwiGLU output
timmoon10 Jan 22, 2026
a098cc0
Add scaled SwiGLU op
timmoon10 Jan 23, 2026
e4f51d3
Avoid CPU splits in group GEMM+SwiGLU kernel
timmoon10 Jan 23, 2026
fb28b6e
Debug scaled SwiGLU
timmoon10 Jan 23, 2026
b0bf34d
Handle case where fused kernel is not available
timmoon10 Jan 24, 2026
000c273
Revert to plain tensor concat
timmoon10 Jan 24, 2026
e2ea4d2
Support GLU interleaving in plain SwiGLU op
timmoon10 Jan 24, 2026
4c6c35f
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 24, 2026
caf580b
Remove MultiplyExtraInput op
timmoon10 Jan 24, 2026
b36007e
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Jan 25, 2026
36e6918
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Jan 25, 2026
ba28c6f
Fix linter warnings
timmoon10 Jan 25, 2026
575da6e
Review suggestions from @greptile-apps
timmoon10 Jan 26, 2026
46294be
Apply suggestion from @greptile-apps[bot]
timmoon10 Jan 26, 2026
fccb0bb
Tweak variable names
timmoon10 Jan 29, 2026
4259e27
Fix f-strings
timmoon10 Jan 30, 2026
2442d34
Fix bug when grouped MLP is not being trained
timmoon10 Jan 30, 2026
a7351e5
Fix f-string
timmoon10 Jan 31, 2026
9be1c49
Replace explicit concat with optional concat
timmoon10 Jan 31, 2026
2c6d6be
Add comments for CuTe DSL expected tensor shapes
timmoon10 Feb 3, 2026
65cb77d
Support contiguous weights in grouped linear op
timmoon10 Feb 4, 2026
2e577d1
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 4, 2026
b02df2b
Update forward kernel API
timmoon10 Feb 5, 2026
49978c6
Add fused op for backward grouped MLP
timmoon10 Feb 6, 2026
790ca14
Debug correctness error in CuTe DSL kernel
timmoon10 Feb 6, 2026
f0b8366
[pre-commit.ci] auto fixes from pre-commit.com hooks
pre-commit-ci[bot] Feb 6, 2026
504f389
Avoid unnecessary weight transpose
timmoon10 Feb 6, 2026
47f942f
Fix linter warnings and comment typos
timmoon10 Feb 6, 2026
ccb091e
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Feb 7, 2026
73d5913
Fix API change in backward kernel
timmoon10 Feb 7, 2026
749f110
Fix dtype mismatch
timmoon10 Feb 8, 2026
89b3283
Fix incorrect FP8 dtype
timmoon10 Feb 9, 2026
61e51d7
Add ops for MoE grouped MLP
timmoon10 Feb 9, 2026
113a163
Move testing utility functions to util submodule
timmoon10 Feb 11, 2026
1466289
Merge branch 'tmoon/grouped-linear-op' into tmoon/cute-gemm-swiglu
timmoon10 Feb 11, 2026
0c2b1bc
Tweak docs
timmoon10 Feb 11, 2026
78e1800
Change order of tensor compatibility checks in noop_cat
timmoon10 Feb 12, 2026
8888100
Merge branch 'main' into tmoon/grouped-linear-op
timmoon10 Feb 12, 2026
da74a35
Add support for GLU interleaving in clamped SwiGLU
timmoon10 Feb 12, 2026
9b97348
Merge branch 'tmoon/grouped-linear-op' into tmoon/cute-gemm-swiglu
timmoon10 Feb 12, 2026
329443c
Merge branch 'main' into tmoon/cute-gemm-swiglu
timmoon10 Feb 12, 2026
21 changes: 21 additions & 0 deletions tests/pytorch/test_fusible_ops.py
@@ -3380,6 +3380,27 @@ def test_grouped_mlp(
    y_test = module(x_test, split_sizes, probs_test, split_sizes)
    y_test.backward(dy_test)

    # Check for expected fusions
    if (
        te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()
        and quantization == "mxfp8"
        and dtype == torch.bfloat16
        and not bias
        and glu_interleave_size == 32
    ):
        forward_ops = module._module_groups[0]._forward_ops
        backward_ops = module._module_groups[0]._backward_ops
        assert len(forward_ops) == 1
        assert isinstance(
            forward_ops[0][0],
            te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
        )
        assert len(backward_ops) == 1
        assert isinstance(
            backward_ops[0][0],
            te_ops.fused.BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
        )

    # Loose tols for sanity checking
    tols = {"rtol": 0.125, "atol": 0.25}
    if quantization == "nvfp4":
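For context, the test only expects the fusion under one narrow configuration. Below is a hypothetical helper (not part of the diff) that restates that eligibility condition; the `transformer_engine.pytorch.ops` import alias is an assumption modeled on the `te_ops` name used in this test file.

# Hypothetical helper, not in the diff: restates the condition under which the
# test above expects the CuTe GEMM+SwiGLU fusion to be chosen.
import torch
import transformer_engine.pytorch.ops as te_ops

def expect_grouped_mlp_fusion(quantization, dtype, bias, glu_interleave_size):
    """True when the MXFP8 grouped-MLP forward/backward fusions should apply."""
    return (
        te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()
        and quantization == "mxfp8"
        and dtype == torch.bfloat16
        and not bias
        and glu_interleave_size == 32
    )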
9 changes: 9 additions & 0 deletions transformer_engine/pytorch/ops/fused/__init__.py
@@ -28,3 +28,12 @@
register_backward_fusion(BackwardLinearScale.fuse_backward_ops)
register_backward_fusion(BackwardActivationBias.fuse_backward_ops)
register_backward_fusion(BackwardAddRMSNorm.fuse_backward_ops)

# Import experimental fusions
# Note: Registration logic is non-trivial, so each submodule handles it internally.
from .forward_grouped_mlp import (  # pylint: disable=wrong-import-position
    ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8,
)
from .backward_grouped_mlp import (  # pylint: disable=wrong-import-position
    BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8,
)
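The note above says the experimental submodules register their fusions themselves rather than in this `__init__.py`. Below is a minimal, self-contained sketch of that pattern; it is an assumption, not the actual submodule code. The toy registry, `register_forward_fusion`, and `fuse_forward_ops` are stand-ins modeled on the `register_backward_fusion(...fuse_backward_ops)` calls earlier in this file.

# Toy sketch of availability-gated registration; everything here is a stand-in
# for the real (more involved) logic inside forward_grouped_mlp.py.
from typing import Callable, List

_FORWARD_FUSIONS: List[Callable] = []

def register_forward_fusion(fuse_fn: Callable) -> None:
    """Record a forward-pass fusion pattern (toy registry)."""
    _FORWARD_FUSIONS.append(fuse_fn)

class _ForwardGroupedMLPFusion:
    """Stand-in for ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8."""

    @classmethod
    def is_supported(cls) -> bool:
        # The real check depends on GPU arch and whether the CuTe DSL kernel imports.
        return False

    @classmethod
    def fuse_forward_ops(cls, ops):
        # The real logic pattern-matches the op list and substitutes the fused op.
        return ops

# Register only when the fused kernel is actually usable, which is why the
# package __init__ merely imports the submodule instead of registering here.
if _ForwardGroupedMLPFusion.is_supported():
    register_forward_fusion(_ForwardGroupedMLPFusion.fuse_forward_ops)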