
Conversation

@yaox12 (Member) commented Jan 29, 2026

Description

  • Added a new score function sqrtsoftplus
  • Added tests
  • Switched from FP64 to FP32 math for better performance on SM103. Precision principles (see the sketch below):
    • Inputs are cast to FP32 when loaded from global memory.
    • All math in the kernels is done in FP32.
    • scores is always FP32 (matching the MCore implementation).
    • intermediate_output is always FP32 for better backward precision.
    • Casts to low precision happen only where necessary, e.g. the gradient of the input must have the same dtype as the input.
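
A minimal PyTorch sketch of these precision rules — a reference for the rules above, not the CUDA kernels themselves; the function names are illustrative:

import torch
import torch.nn.functional as F

def sqrtsoftplus_scores_ref(logits: torch.Tensor) -> torch.Tensor:
    x = logits.float()            # inputs cast to FP32 when loaded
    return F.softplus(x).sqrt()   # all math in FP32; scores stay FP32

def cast_grad_ref(grad_fp32: torch.Tensor, logits: torch.Tensor) -> torch.Tensor:
    # Cast to low precision only where required: the input's gradient
    # must match the input dtype.
    return grad_fp32.to(logits.dtype)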

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Added the sqrtsoftplus score function to the fused topk router and aux-loss score kernels, with expert_bias support.
  • Switched the fused router kernels from FP64 to FP32 math.
  • Added unit tests covering sqrtsoftplus.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@yaox12 yaox12 self-assigned this Feb 6, 2026
@yaox12 yaox12 marked this pull request as ready for review February 6, 2026 06:41
@yaox12 yaox12 added the MoE label Feb 6, 2026
greptile-apps bot (Contributor) commented Feb 6, 2026

Greptile Summary

This PR adds a new router score function, sqrtsoftplus, to the fused topk router and aux-loss score computation.

Key integration points:

  • CUDA kernels now support score_function==2 and use fp32 intermediate buffers (intermediate_output, aux-loss scores) while keeping probs/grad_logits in the original dtype.
  • PyTorch bindings are updated so backward APIs write into a caller-provided grad_logits tensor.
  • Unit tests add coverage for sqrtsoftplus for both fused topk routing and aux-loss score paths.

Main issues to address before merge are a Python autograd backward return-arity mismatch (runtime error) and numerical stability in the new sqrtsoftplus implementation (possible inf/NaN for large logits).

Confidence Score: 3/5

  • This PR is close but has at least one definite runtime error and one numerical stability concern to fix before merging.
  • The new score function is wired through CUDA, C++ bindings, Python wrappers, and tests, but FusedTopkScoreFunction.backward returns the wrong number of gradients for its forward signature (will raise in autograd). Additionally, the sqrtsoftplus forward uses an overflow-prone softplus formulation that can produce inf/NaN for large logits.
  • transformer_engine/pytorch/router.py, transformer_engine/common/fused_router/utils.h

Important Files Changed

  • tests/pytorch/test_fused_router.py — Adds sqrtsoftplus coverage to the fused router tests and refactors the reference implementation; looks consistent with the CUDA semantics (FP32 math, then cast).
  • transformer_engine/common/fused_router/fused_moe_aux_loss.cu — Changes the aux-loss accumulation type from double to float; this likely changes numerical results/gradients and may break existing expectations at large token/expert counts.
  • transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu — Adds the sqrtsoftplus path and moves the intermediate/scores buffers to float; backward normalization logic updated. Needs careful verification of dtype assumptions and the normalization gradient.
  • transformer_engine/common/fused_router/fused_topk_with_score_function.cu — Adds sqrtsoftplus routing plus expert_bias support; changes intermediate_output to float and expert_bias to float-only. Potential correctness risk around intermediate_output semantics and bias dtype constraints.
  • transformer_engine/common/fused_router/utils.h — Adds sqrtsoftplus forward/backward helpers and converts reductions to float; the sqrtsoftplus forward uses log1pf(expf(x)), which can overflow for large positive x.
  • transformer_engine/common/include/transformer_engine/fused_router.h — Updates public C API comments to document sqrtsoftplus and expert_bias behavior; no functional changes.
  • transformer_engine/pytorch/csrc/extensions.h — Updates PyTorch extension signatures to pass a grad_logits output tensor and switches to std::optional; aligns with the new pybind API.
  • transformer_engine/pytorch/csrc/extensions/pybind.cpp — Adjusts router pybind bindings and docstrings for the new bwd signatures and naming; looks consistent with the C++ entry points.
  • transformer_engine/pytorch/csrc/extensions/router.cpp — Adds the sqrtsoftplus mapping, forces intermediate_output/scores to FP32, enforces FP32 expert_bias, and changes bwd to write into the provided grad_logits. Needs alignment with the Python wrapper; potential API-compatibility concern.
  • transformer_engine/pytorch/router.py — Extends the Python API to sqrtsoftplus and updates bwd to allocate grad_logits and call the new out-parameter bindings. Risk: backward now returns 8 grads but forward takes 9 args (missing the expert_bias grad slot).

Sequence Diagram

sequenceDiagram
  participant Py as transformer_engine/pytorch/router.py
  participant Tex as transformer_engine_torch (pybind)
  participant Cpp as pytorch/csrc/extensions/router.cpp
  participant CudaTopk as common/fused_router/fused_topk_with_score_function.cu
  participant CudaAux as common/fused_router/fused_score_for_moe_aux_loss.cu

  Py->>Tex: fused_topk_with_score_function(..., score_function, expert_bias)
  Tex->>Cpp: fused_topk_with_score_function_fwd(logits,...)
  Cpp->>CudaTopk: nvte_fused_topk_with_score_function_forward(score_function_id)
  CudaTopk-->>Cpp: probs, routing_map, intermediate_output(float)
  Cpp-->>Tex: return tensors
  Tex-->>Py: probs, routing_map

  Py->>Tex: fused_topk_with_score_function_bwd(..., grad_probs, grad_logits(out))
  Tex->>Cpp: fused_topk_with_score_function_bwd(..., grad_logits)
  Cpp->>CudaTopk: nvte_fused_topk_with_score_function_backward(...)
  CudaTopk-->>Cpp: writes grad_logits

  Py->>Tex: fused_score_for_moe_aux_loss_fwd(logits, topk, score_function)
  Tex->>Cpp: fused_score_for_moe_aux_loss_fwd
  Cpp->>CudaAux: nvte_fused_score_for_moe_aux_loss_forward
  CudaAux-->>Cpp: scores(float), routing_map, intermediate_output(float)

  Py->>Tex: fused_score_for_moe_aux_loss_bwd(..., grad_scores, grad_logits(out))
  Tex->>Cpp: fused_score_for_moe_aux_loss_bwd(..., grad_logits)
  Cpp->>CudaAux: nvte_fused_score_for_moe_aux_loss_backward
  CudaAux-->>Cpp: writes grad_logits

greptile-apps bot left a comment — 7 files reviewed, 2 comments

greptile-apps bot commented Feb 6, 2026

Additional Comments (1)

transformer_engine/common/include/transformer_engine/fused_router.h
Expert bias API mismatch

The header still documents expert_bias as "Only used at the sigmoid case" (line 27), but this PR enables expert_bias for sqrtsoftplus end-to-end (kernels handle score_function==2, and router.cpp allows it). Please update the API docs here (and any other public docs) to reflect that expert_bias is supported for sqrtsoftplus too, otherwise external callers will get contradictory guidance.

greptile-apps bot left a comment — 4 files reviewed, 1 comment

greptile-apps bot left a comment — 8 files reviewed, 1 comment

@ptrendx ptrendx requested a review from denera February 10, 2026 22:37
greptile-apps bot left a comment — 10 files reviewed, 1 comment

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

tests/pytorch/test_fused_router.py
Missing expert_bias grad check

run_comparison() enables expert_bias.requires_grad = True for score_function in ("sigmoid", "sqrtsoftplus") (tests/pytorch/test_fused_router.py:177-180), but the test only asserts logits.grad equality (line ~230) and never checks expert_bias.grad vs expert_bias_clone.grad. This will let regressions in the fused backward w.r.t. expert_bias merge undetected for the bias-enabled paths.
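
A hedged sketch of the missing assertion, meant to live inside run_comparison(); the tensor names (expert_bias for the fused path, expert_bias_clone for the reference path) follow the comment above, and the tolerances are illustrative:

if score_function in ("sigmoid", "sqrtsoftplus"):
    # Compare the fused kernel's bias gradient against the reference path.
    torch.testing.assert_close(expert_bias.grad, expert_bias_clone.grad,
                               rtol=1e-3, atol=1e-3)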

@yaox12 yaox12 force-pushed the xiny/add_score_func branch from 70136ee to 3bace73 on February 11, 2026 09:53
greptile-apps bot left a comment — 10 files reviewed, 3 comments

Comment on lines +111 to +124
// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x)))
__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) {
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    float x = scores[i];
    // softplus(x) = log(1 + exp(x)), numerically stable version
    // Matches PyTorch's Softplus(beta=1.0, threshold=20.0)
    float softplus_val;
    if (x > 20.0f) {
      softplus_val = x;  // for large x, softplus(x) ≈ x
    } else {
      softplus_val = log1pf(expf(x));
    }
    scores[i] = sqrtf(softplus_val);
  }
Softplus overflow risk

In apply_sqrtsoftplus_on_float, the else branch uses log1pf(expf(x)). For very negative x, expf(x) underflows to 0, which is fine, and with the x > 20.0f guard the else branch only sees x ≤ 20 (expf(20.0f) ≈ 4.85e8, well within float range). However, any sqrtsoftplus helper that omits the threshold guard (as flagged for utils.h above) will see expf(x) overflow to inf for x ≳ 88, producing inf outputs and NaNs downstream. Consider the branch-free, numerically stable softplus formulation max(x, 0) + log1p(exp(-|x|)), which avoids overflow everywhere while matching PyTorch semantics.
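
A minimal PyTorch reference of the stable formulation suggested above (the CUDA equivalent would use fmaxf, log1pf, expf, and fabsf); the function name is illustrative:

import torch

def sqrtsoftplus_stable_ref(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0) + log1p(exp(-|x|)) equals softplus(x) for all x but never
    # exponentiates a positive argument, so it cannot overflow.
    softplus = torch.clamp(x, min=0.0) + torch.log1p(torch.exp(-x.abs()))
    return torch.sqrt(softplus)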

Comment on lines +114 to +118
} else if (score_function == 2) {  // score_function == 2 means sqrtsoftplus
  // First save the original logits for backward (needed for sigmoid computation)
  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
    intermediate_output[pos_offset + i] = scores[i];  // Save original logits
  }
Incorrect sqrtsoftplus comment

The comment says “save the original logits for backward (needed for sigmoid computation)”, but the backward path uses these saved values as logits for apply_sqrtsoftplus_bwd_on_float (and recomputes sqrtsoftplus output separately). This is misleading for future maintenance/debugging; update the comment to reflect that intermediate_output stores original logits for sqrtsoftplus gradient computation.

greptile-apps bot commented Feb 11, 2026

Additional Comments (1)

transformer_engine/pytorch/router.py
Backward returns wrong arity

FusedTopkScoreFunction.forward takes 9 inputs after ctx (including expert_bias), but backward currently returns only 8 entries (grad_logits, None, None, None, None, None, None, None). Autograd expects one gradient slot per forward input, so this will raise at runtime. Add a trailing None for the expert_bias gradient (and keep the ordering aligned with the forward signature).
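
A self-contained toy with the same shape of problem, to illustrate the arity rule; the input names a1–a7 are placeholders, not the real forward signature in router.py:

import torch

class NineInputSketch(torch.autograd.Function):
    # Toy stand-in: 9 forward inputs after ctx, like FusedTopkScoreFunction.
    @staticmethod
    def forward(ctx, logits, a1, a2, a3, a4, a5, a6, a7, expert_bias):
        return logits * 2.0

    @staticmethod
    def backward(ctx, grad_out):
        # One gradient slot per forward input: nine entries, with a trailing
        # None for expert_bias. Returning only eight raises a RuntimeError.
        return grad_out * 2.0, None, None, None, None, None, None, None, None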

auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
@yaox12 (Member, Author) commented Feb 11, 2026
Creating the output tensor on the Python side because the intermediate_output is now always in FP32, while the grad_logits should have the same dtype as the input.
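
A minimal sketch of the pattern being described, with torch.empty_like standing in for the actual allocation in router.py (an assumption, not the exact code):

import torch

def alloc_grad_logits(logits: torch.Tensor) -> torch.Tensor:
    # Kernel math and intermediate_output stay FP32; the gradient buffer is
    # allocated caller-side with the input's dtype and passed to the backward
    # binding as an out-parameter.
    return torch.empty_like(logits)  # same dtype/shape/device as the input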

