[Common][PyTorch] Add a new score func sqrtsoftplus to the fused router
#2633
Conversation
Greptile Summary
This PR adds a new router score function, sqrtsoftplus, to the fused router kernels and their PyTorch bindings; the key integration points are shown in the sequence diagram below.
The main issues to address before merge are a Python autograd backward return-arity mismatch (a runtime error) and numerical stability in the new sqrtsoftplus implementation (possible inf/NaN for large logits).
Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Py as transformer_engine/pytorch/router.py
    participant Tex as transformer_engine_torch (pybind)
    participant Cpp as pytorch/csrc/extensions/router.cpp
    participant CudaTopk as common/fused_router/fused_topk_with_score_function.cu
    participant CudaAux as common/fused_router/fused_score_for_moe_aux_loss.cu
    Py->>Tex: fused_topk_with_score_function(..., score_function, expert_bias)
    Tex->>Cpp: fused_topk_with_score_function_fwd(logits,...)
    Cpp->>CudaTopk: nvte_fused_topk_with_score_function_forward(score_function_id)
    CudaTopk-->>Cpp: probs, routing_map, intermediate_output(float)
    Cpp-->>Tex: return tensors
    Tex-->>Py: probs, routing_map
    Py->>Tex: fused_topk_with_score_function_bwd(..., grad_probs, grad_logits(out))
    Tex->>Cpp: fused_topk_with_score_function_bwd(..., grad_logits)
    Cpp->>CudaTopk: nvte_fused_topk_with_score_function_backward(...)
    CudaTopk-->>Cpp: writes grad_logits
    Py->>Tex: fused_score_for_moe_aux_loss_fwd(logits, topk, score_function)
    Tex->>Cpp: fused_score_for_moe_aux_loss_fwd
    Cpp->>CudaAux: nvte_fused_score_for_moe_aux_loss_forward
    CudaAux-->>Cpp: scores(float), routing_map, intermediate_output(float)
    Py->>Tex: fused_score_for_moe_aux_loss_bwd(..., grad_scores, grad_logits(out))
    Tex->>Cpp: fused_score_for_moe_aux_loss_bwd(..., grad_logits)
    Cpp->>CudaAux: nvte_fused_score_for_moe_aux_loss_backward
    CudaAux-->>Cpp: writes grad_logits
```
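A minimal sketch of how the Python side might wire these bindings into a torch.autograd.Function; this is not the PR's actual router.py code, and the argument lists are simplified and assumed from the diagram above:

```python
import torch
import transformer_engine_torch as tex  # pybind module from the diagram


class _FusedTopkWithScoreFunction(torch.autograd.Function):
    """Schematic wrapper; the real function in router.py takes more arguments."""

    @staticmethod
    def forward(ctx, logits, score_function, expert_bias):
        # Assumed call shape: fwd returns probs, routing_map, and an FP32
        # intermediate buffer that the backward kernel reuses.
        probs, routing_map, intermediate = tex.fused_topk_with_score_function_fwd(
            logits, score_function, expert_bias
        )
        ctx.save_for_backward(routing_map, intermediate)
        ctx.logits_meta = (logits.shape, logits.dtype, logits.device)
        ctx.mark_non_differentiable(routing_map)
        return probs, routing_map

    @staticmethod
    def backward(ctx, grad_probs, _grad_routing_map):
        routing_map, intermediate = ctx.saved_tensors
        shape, dtype, device = ctx.logits_meta
        # grad_logits keeps the input dtype even though `intermediate` is FP32
        # (see the allocation discussion further down).
        grad_logits = torch.empty(shape, dtype=dtype, device=device)
        tex.fused_topk_with_score_function_bwd(
            routing_map, intermediate, grad_probs, grad_logits
        )
        # One gradient (or None) per forward input: logits, score_function, expert_bias.
        return grad_logits, None, None
```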
7 files reviewed, 2 comments
Additional Comments (1)
The header still documents …
4 files reviewed, 1 comment
(An earlier comment on transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu is now outdated and marked resolved.)
8 files reviewed, 1 comment
10 files reviewed, 1 comment
Signed-off-by: Xin Yao <xiny@nvidia.com>
Force-pushed from 70136ee to 3bace73.
10 files reviewed, 3 comments
```cuda
// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x)))
__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) {
  for (int i = lane_id; i < data_size; i += kThreadsPerWarp) {
    float x = scores[i];
    // softplus(x) = log(1 + exp(x)), numerically stable version
    // Matches PyTorch's Softplus(beta=1.0, threshold=20.0)
    float softplus_val;
    if (x > 20.0f) {
      softplus_val = x;  // for large x, softplus(x) ≈ x
    } else {
      softplus_val = log1pf(expf(x));
    }
    scores[i] = sqrtf(softplus_val);
  }
```
Softplus overflow for large positive logits
In apply_sqrtsoftplus_on_float, the else branch uses log1pf(expf(x)). For very negative x (common for gating logits), expf(x) underflows to 0, which is fine, but the formulation leans entirely on the 20.0f threshold guard: for large positive x (e.g., x ≈ 90), expf(x) would overflow to inf without it, making softplus_val inf and producing inf/NaN downstream. Consider the branch-free, numerically stable softplus formulation max(x, 0) + log1p(exp(-abs(x))), which avoids overflow for any x while still matching PyTorch semantics.
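As a quick check of the suggested formulation against PyTorch's own softplus (a throwaway reference snippet; the helper name is illustrative, not part of the PR):

```python
import torch
import torch.nn.functional as F

def softplus_stable(x: torch.Tensor) -> torch.Tensor:
    # max(x, 0) + log1p(exp(-|x|)) never feeds exp() a large positive value,
    # so it cannot overflow regardless of any threshold choice.
    return torch.clamp(x, min=0) + torch.log1p(torch.exp(-x.abs()))

x = torch.tensor([-90.0, -1.0, 0.0, 19.9, 20.1, 90.0])
assert torch.allclose(softplus_stable(x), F.softplus(x, beta=1.0, threshold=20.0), atol=1e-6)
```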
```cuda
} else if (score_function == 2) {  // score_function == 2 means sqrtsoftplus
  // First save the original logits for backward (needed for sigmoid computation)
  for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) {
    intermediate_output[pos_offset + i] = scores[i];  // Save original logits
  }
```
Incorrect sqrtsoftplus comment
The comment says “save the original logits for backward (needed for sigmoid computation)”, but the backward path uses these saved values as logits for apply_sqrtsoftplus_bwd_on_float (and recomputes sqrtsoftplus output separately). This is misleading for future maintenance/debugging; update the comment to reflect that intermediate_output stores original logits for sqrtsoftplus gradient computation.
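For reference, the gradient the backward kernel derives from those saved logits is d/dx sqrt(softplus(x)) = sigmoid(x) / (2·sqrt(softplus(x))); a small PyTorch sanity check (the helper name is illustrative, not the kernel code):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_bwd_ref(logits: torch.Tensor, grad_out: torch.Tensor) -> torch.Tensor:
    # d/dx sqrt(softplus(x)) = softplus'(x) / (2 * sqrt(softplus(x)))
    #                        = sigmoid(x)   / (2 * sqrt(softplus(x)))
    return grad_out * torch.sigmoid(logits) / (2.0 * torch.sqrt(F.softplus(logits)))

logits = torch.randn(4, 8, dtype=torch.float64, requires_grad=True)
out = torch.sqrt(F.softplus(logits))
grad_out = torch.ones_like(out)
(auto_grad,) = torch.autograd.grad(out, logits, grad_out)
assert torch.allclose(auto_grad, sqrtsoftplus_bwd_ref(logits.detach(), grad_out))
```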
Additional Comments (1)
```cpp
auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f;
auto score_function_value = score_function_map[score_function];
// Init the output tensor
at::Tensor grad_logits = at::empty(
```
Creating the output tensor on the Python side, because intermediate_output is now always in FP32 while grad_logits should have the same dtype as the input.
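A tiny illustration of that dtype split (shapes and dtypes are illustrative, not taken from the PR):

```python
import torch

logits = torch.randn(4, 8, dtype=torch.bfloat16)               # router input dtype
intermediate = torch.empty(logits.shape, dtype=torch.float32)  # kept in FP32 for backward precision
grad_logits = torch.empty_like(logits)                         # gradient follows the input dtype (bf16)
```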
Description
Add a new score function, sqrtsoftplus, to the fused router.
- scores is always in FP32 (matching the MCore implementation).
- intermediate_output is always in FP32 for better backward precision.
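An unfused PyTorch reference of the new score function is handy for parity tests against the fused kernels; this mirrors the description above but is not code from the PR (the helper name is illustrative):

```python
import torch
import torch.nn.functional as F

def sqrtsoftplus_scores(logits: torch.Tensor) -> torch.Tensor:
    # Compute in FP32 regardless of the input dtype, mirroring the note above
    # that scores / intermediate_output are kept in FP32.
    return torch.sqrt(F.softplus(logits.float(), beta=1.0, threshold=20.0))

logits = torch.randn(2, 8, dtype=torch.bfloat16)
probs, topk_idx = torch.topk(sqrtsoftplus_scores(logits), k=2, dim=-1)
```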