diff --git a/tests/pytorch/test_fused_router.py b/tests/pytorch/test_fused_router.py index f559362d82..64000e109e 100644 --- a/tests/pytorch/test_fused_router.py +++ b/tests/pytorch/test_fused_router.py @@ -47,7 +47,7 @@ def group_limited_topk( # Pytorch-based topk softmax/sigmoid -def topk_softmax_sigmoid_pytorch( +def topk_score_function_pytorch( logits: torch.Tensor, topk: int, use_pre_softmax: bool = False, @@ -74,17 +74,20 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if score_function == "softmax": if use_pre_softmax: - scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits) + scores = torch.softmax(logits, dim=-1, dtype=torch.float32) probs, top_indices = compute_topk(scores, topk, num_groups, group_topk) else: scores, top_indices = compute_topk(logits, topk, num_groups, group_topk) - probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits) - elif score_function == "sigmoid": - scores = torch.sigmoid(logits.float()).type_as(logits) + probs = torch.softmax(scores, dim=-1, dtype=torch.float32) + elif score_function in ("sigmoid", "sqrtsoftplus"): + if score_function == "sigmoid": + scores = torch.sigmoid(logits.float()) + else: + scores = torch.nn.functional.softplus(logits.float()).sqrt() if expert_bias is not None: scores_for_routing = scores + expert_bias _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk) - scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits) + scores = torch.gather(scores, dim=1, index=top_indices) else: scores, top_indices = compute_topk(scores, topk, num_groups, group_topk) probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores @@ -94,6 +97,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None): if scaling_factor: probs = probs * scaling_factor + probs = probs.type_as(logits) + topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs) topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool() @@ -107,7 +112,10 @@ def compute_scores_for_aux_loss_pytorch( if score_function == "softmax": scores = torch.softmax(logits, dim=-1, dtype=torch.float32) elif score_function == "sigmoid": - scores = torch.sigmoid(logits) + scores = torch.sigmoid(logits.float()) + scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores + elif score_function == "sqrtsoftplus": + scores = torch.nn.functional.softplus(logits.float()).sqrt() scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores else: raise ValueError(f"Invalid score_function: {score_function}") @@ -146,8 +154,9 @@ def run_comparison( enable_bias, ): # Set some parameters - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values, + # which would cause precision loss when adding/subtracting expert bias in float32. 
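For reference, the sqrtsoftplus branch added to the PyTorch baseline above reduces to the following standalone sketch (illustrative only: group-limited routing and the scaling factor are omitted, the helper name is made up, and the 1e-20 epsilon mirrors the test code):

import torch

def sqrtsoftplus_topk_reference(logits, topk, expert_bias=None):
    # scores = sqrt(softplus(x)), computed in float32 like the reference above
    scores = torch.nn.functional.softplus(logits.float()).sqrt()
    # expert bias only influences which experts are selected, not the returned weights
    routing_scores = scores + expert_bias if expert_bias is not None else scores
    _, top_indices = torch.topk(routing_scores, k=topk, dim=-1)
    probs = torch.gather(scores, dim=1, index=top_indices)
    if topk > 1:
        probs = probs / (probs.sum(dim=-1, keepdim=True) + 1e-20)
    return probs, top_indices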
offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -165,8 +174,8 @@ def run_comparison( ) logits = logits.view(num_tokens, num_experts) logits.requires_grad = True - if enable_bias and score_function == "sigmoid": - expert_bias = torch.arange(num_experts, device="cuda") * 0.1 + if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"): + expert_bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1 expert_bias = torch.flip(expert_bias, dims=[0]) expert_bias.requires_grad = True else: @@ -183,7 +192,7 @@ def run_comparison( # Run the original implementation # We do not support the capacity factor case - probs, routing_map = topk_softmax_sigmoid_pytorch( + probs, routing_map = topk_score_function_pytorch( logits=logits, topk=topk, use_pre_softmax=use_pre_softmax, @@ -252,6 +261,37 @@ def test_topk_sigmoid( ) +@pytest.mark.parametrize("dtype", [torch.float32]) +@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992]) +@pytest.mark.parametrize("num_experts", [128, 32]) +@pytest.mark.parametrize("topk", [4, 8]) +@pytest.mark.parametrize("group_topk", [None, 4]) +@pytest.mark.parametrize("scaling_factor", [None, 1.2]) +@pytest.mark.parametrize("enable_bias", [True, False]) +def test_topk_sqrtsoftplus( + dtype, + num_tokens, + num_experts, + topk, + group_topk, + scaling_factor, + enable_bias, +): + num_groups = 8 if group_topk else None + run_comparison( + dtype=dtype, + num_tokens=num_tokens, + num_experts=num_experts, + topk=topk, + use_pre_softmax=False, + num_groups=num_groups, + group_topk=group_topk, + scaling_factor=scaling_factor, + score_function="sqrtsoftplus", + enable_bias=enable_bias, + ) + + @pytest.mark.parametrize("dtype", [torch.float32]) @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [128, 32]) @@ -287,10 +327,10 @@ def test_topk_softmax( @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234]) @pytest.mark.parametrize("num_experts", [256, 128, 32]) @pytest.mark.parametrize("topk", [4, 8]) -@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"]) +@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"]) def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function): - if score_function == "sigmoid": - # Construct the special logits to avoid inf in the sigmoid function + if score_function in ("sigmoid", "sqrtsoftplus"): + # Construct logits with a narrow range to avoid very small activation values offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4 logits = ( torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2 @@ -396,15 +436,6 @@ def profile_topk_softmax( test_topk_softmax( torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor ) - - -if __name__ == "__main__": - test_topk_softmax( - dtype=torch.float32, - num_tokens=1024, - num_experts=128, - topk=4, - use_pre_softmax=False, - group_topk=None, - scaling_factor=None, + test_topk_sqrtsoftplus( + torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias ) diff --git a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu index 2aa2805fed..e22ca4aed5 100644 --- a/transformer_engine/common/fused_router/fused_moe_aux_loss.cu +++ 
b/transformer_engine/common/fused_router/fused_moe_aux_loss.cu @@ -18,7 +18,7 @@ namespace transformer_engine { // Using Double to hanld all the calculations -using CompType = double; +using CompType = float; template __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, @@ -98,7 +98,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Compute the aux_loss */ float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + aux_loss[0] = static_cast(intermediate_result * C_coeff); Const_buf[0] = C_coeff; } } @@ -154,7 +154,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs, * Section: Compute the aux_loss */ float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens; - aux_loss[0] = static_cast(static_cast(intermediate_result) * C_coeff); + aux_loss[0] = static_cast(intermediate_result * C_coeff); Const_buf[0] = C_coeff; } } @@ -229,8 +229,8 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf, // Loop: for all positions in each row for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) { float C_coeff = Const_buf[0]; - double tokens_per_expert_i = static_cast(tokens_per_expert[i]); - double grad_aux_loss_value = static_cast(grad_aux_loss[0]); + CompType tokens_per_expert_i = static_cast(tokens_per_expert[i]); + CompType grad_aux_loss_value = static_cast(grad_aux_loss[0]); // Loop: for all rows for (int j = global_warp_id; j < num_rows; j += global_warp_num) { grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value; diff --git a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu index 7540b5c41d..cf0eafca84 100644 --- a/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu +++ b/transformer_engine/common/fused_router/fused_score_for_moe_aux_loss.cu @@ -18,13 +18,11 @@ namespace transformer_engine { template __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logits, int num_tokens, int num_experts, int topk, - int score_function, DataType *scores, + int score_function, float *scores, bool *routing_map, - DataType *intermediate_output) { + float *intermediate_output) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), - * So DataType address is assigned firstly to avoid the alignment issue * - Each warp is responsible for one token, and has own shared memory buffer. 
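The accumulator change above (double to float) does not alter the aux-loss math itself. As a rough PyTorch cross-check (the forward reduction is inferred from the backward kernel shown above, so treat the shapes and naming as assumptions):

import torch

def moe_aux_loss_reference(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff):
    # Matches the kernel's C_coeff = (num_experts * coeff) / topk / total_num_tokens^2
    c_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens
    # aux_loss = sum_e (sum_t probs[t, e]) * tokens_per_expert[e] * C_coeff, hence
    # d(aux_loss)/d(probs[t, e]) = C_coeff * tokens_per_expert[e] * grad_aux_loss,
    # which is exactly what the backward kernel writes for every row.
    return (probs.sum(dim=0) * tokens_per_expert.float()).sum() * c_coeff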
* Then __syncwarp() is used instead of __syncthreads() */ @@ -33,13 +31,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem_scores_for_aux_loss[]; - DataType *logits_buf = reinterpret_cast(shmem_scores_for_aux_loss); - DataType *topk_logits_buf = - reinterpret_cast(logits_buf + num_experts * num_token_per_block); + float *logits_buf = shmem_scores_for_aux_loss; + float *topk_logits_buf = logits_buf + num_experts * num_token_per_block; int *topk_indices_buf = reinterpret_cast(topk_logits_buf + topk * num_token_per_block); // The address of buffers on the current warp - DataType *local_logits = logits_buf + warp_id * num_experts; - DataType *topk_logits = topk_logits_buf + warp_id * topk; + float *local_logits = logits_buf + warp_id * num_experts; + float *topk_logits = topk_logits_buf + warp_id * topk; int *topk_indices = topk_indices_buf + warp_id * topk; /*** @@ -63,12 +60,12 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { routing_map[pos_offset + i] = false; if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } // Load the logits to shmem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = logits[pos_offset + i]; + local_logits[i] = static_cast(logits[pos_offset + i]); } __threadfence_block(); __syncwarp(); @@ -78,11 +75,11 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid - * - Sigmoid post-processing when topk > 1 + * - Sqrtsoftplus + * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * This is in-place scores update */ - // score_function == 1 means softmax - if (score_function == 1) { + if (score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -90,10 +87,7 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function == 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(local_logits, num_experts, lane_id); __syncwarp(); @@ -101,17 +95,25 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = local_logits[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for gradient computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = local_logits[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(local_logits, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - if (score_function == 0) { + // 
Sigmoid/Sqrtsoftplus post-processing when topk > 1 + if (score_function == 0 || score_function == 2) { if (topk > 1) { auto sum_logits = warp_reduce_on_shmem(local_logits, num_experts, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_logits[i] = static_cast(static_cast(local_logits[i]) / - (static_cast(sum_logits) + epsilon)); + local_logits[i] /= (sum_logits + epsilon); } } __syncwarp(); @@ -140,13 +142,13 @@ __global__ void fused_score_for_moe_aux_loss_forward_kernel(const DataType *logi template void fused_score_for_moe_aux_loss_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, int score_function, - DataType *scores, bool *routing_map, DataType *intermediate_output, cudaStream_t stream) { + float *scores, bool *routing_map, float *intermediate_output, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // logits - + topk * num_token_per_block * sizeof(DataType) // topk_logits - + topk * num_token_per_block * sizeof(int); // topk_indices + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(float) // logits + + topk * num_token_per_block * sizeof(float) // topk_logits + + topk * num_token_per_block * sizeof(int); // topk_indices fused_score_for_moe_aux_loss_forward_kernel <<>>( logits, num_tokens, num_experts, topk, score_function, scores, routing_map, @@ -162,20 +164,19 @@ void fused_score_for_moe_aux_loss_forward(const Tensor &logits, int num_tokens, logits.data.dtype, DataType, fused_score_for_moe_aux_loss_forward_kernel_launcher( reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - score_function, reinterpret_cast(scores.data.dptr), + score_function, reinterpret_cast(scores.data.dptr), reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), stream);); + reinterpret_cast(intermediate_output.data.dptr), stream);); } template -__global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *intermediate_output, - const DataType *grad_scores, +__global__ void fused_score_for_moe_aux_loss_backward_kernel(const float *intermediate_output, + const float *grad_scores, int num_tokens, int num_experts, int topk, int score_function, DataType *grad_logits) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), * - Each warp is responsible for one token, and has own shared memory buffer. 
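A per-token reference for what this forward kernel now computes, entirely in float32 (a sketch under assumptions: the function name is illustrative and the routing-map/top-k bookkeeping is omitted):

import torch

def scores_for_aux_loss_reference(logits, topk, score_function):
    x = logits.float()
    if score_function == "softmax":
        return torch.softmax(x, dim=-1)
    if score_function == "sigmoid":
        scores = torch.sigmoid(x)
    elif score_function == "sqrtsoftplus":
        scores = torch.nn.functional.softplus(x).sqrt()
    else:
        raise ValueError(f"Invalid score_function: {score_function}")
    # sigmoid/sqrtsoftplus scores are renormalized over experts when topk > 1
    if topk > 1:
        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20)
    return scores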
* Then __syncwarp() is used instead of __syncthreads() */ @@ -184,16 +185,14 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *grad_scores_buf = reinterpret_cast(shmem); + float *grad_scores_buf = shmem; // To store the output of softmax/sigmoid from the fwd - DataType *act_from_fwd_buf = - reinterpret_cast(grad_scores_buf + num_experts * num_token_per_block); - DataType *comp_buf = - reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + float *act_from_fwd_buf = grad_scores_buf + num_experts * num_token_per_block; + float *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; // The address of buffers on the current warp - DataType *local_grad = grad_scores_buf + warp_id * num_experts; - DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - DataType *local_comp_buf = comp_buf + warp_id * num_experts; + float *local_grad = grad_scores_buf + warp_id * num_experts; + float *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + float *local_comp_buf = comp_buf + warp_id * num_experts; /*** * Section: Main Loop @@ -227,31 +226,50 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int /*** * Section: Backward of ops before the topk * - Pre-softmax bwd - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function == 0) { - auto sum_fwd_input = - warp_reduce_on_shmem(local_act_from_fwd, num_experts, ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = local_grad[i] * local_act_from_fwd[i]; + local_comp_buf[i] = local_act_from_fwd[i]; } __syncwarp(); - auto sum_Output_x_Grad = - warp_reduce_on_shmem(local_comp_buf, num_experts, ReduceFuncType::SUM, lane_id); + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + float *act_output = (score_function == 0) ? 
local_act_from_fwd : local_comp_buf; + + auto sum_fwd_input = + warp_reduce_on_shmem(act_output, num_experts, ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers + float local_sum_Output_x_Grad = 0.0f; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + } + float sum_Output_x_Grad = local_sum_Output_x_Grad; // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_grad[i] = - static_cast(local_grad[i]) / (static_cast(sum_fwd_input) + epsilon) - - static_cast(sum_Output_x_Grad) / - ((static_cast(sum_fwd_input) + epsilon) * - (static_cast(sum_fwd_input) + epsilon)); + local_grad[i] = local_grad[i] / (sum_fwd_input + epsilon) - + sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); } + __syncwarp(); } - __syncwarp(); // Pre-softmax bwd if (score_function == 1) { @@ -264,9 +282,17 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - grad_logits[pos_offset + i] = local_grad[i]; + grad_logits[pos_offset + i] = static_cast(local_grad[i]); } __syncwarp(); } @@ -274,15 +300,14 @@ __global__ void fused_score_for_moe_aux_loss_backward_kernel(const DataType *int template void fused_score_for_moe_aux_loss_backward_kernel_launcher( - const DataType *intermediate_output, const DataType *grad_scores, int num_tokens, - int num_experts, int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { + const float *intermediate_output, const float *grad_scores, int num_tokens, int num_experts, + int topk, int score_function, DataType *grad_logits, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_scores - + - num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(DataType); // comp_buf + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(float) // grad_scores + + num_experts * num_token_per_block * sizeof(float) // act_from_fwd + + num_experts * num_token_per_block * sizeof(float); // comp_buf fused_score_for_moe_aux_loss_backward_kernel <<>>( intermediate_output, grad_scores, num_tokens, num_experts, topk, score_function, @@ -297,8 +322,8 @@ void fused_score_for_moe_aux_loss_backward(const Tensor &intermediate_output, TE_ROUTER_PROBS_TYPE_SWITCH_ALL( grad_scores.data.dtype, DataType, fused_score_for_moe_aux_loss_backward_kernel_launcher( - reinterpret_cast(intermediate_output.data.dptr), - reinterpret_cast(grad_scores.data.dptr), num_tokens, num_experts, topk, + reinterpret_cast(intermediate_output.data.dptr), + 
reinterpret_cast(grad_scores.data.dptr), num_tokens, num_experts, topk, score_function, reinterpret_cast(grad_logits.data.dptr), stream);); } diff --git a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu index 2719c68c97..2a11bdfeb1 100644 --- a/transformer_engine/common/fused_router/fused_topk_with_score_function.cu +++ b/transformer_engine/common/fused_router/fused_topk_with_score_function.cu @@ -15,16 +15,13 @@ namespace transformer_engine { -template +template __global__ void fused_topk_with_score_function_forward_kernel( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, - DataType *intermediate_output) { + const float *expert_bias, DataType *probs, bool *routing_map, float *intermediate_output) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), - * So DataType address is assigned firstly to avoid the alignment issue * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -33,24 +30,22 @@ __global__ void fused_topk_with_score_function_forward_kernel( int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *scores_buf = reinterpret_cast(shmem); - DataType *topk_scores_buf = - reinterpret_cast(scores_buf + num_experts * num_token_per_block); - DataType *group_scores_buf = nullptr, *masked_scores_buf = nullptr; + float *scores_buf = shmem; + float *topk_scores_buf = scores_buf + num_experts * num_token_per_block; + float *group_scores_buf = nullptr, *masked_scores_buf = nullptr; int *topk_indices_buf = nullptr; if (group_topk > 0) { - masked_scores_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); - group_scores_buf = - reinterpret_cast(masked_scores_buf + num_experts * num_token_per_block); + masked_scores_buf = topk_scores_buf + topk * num_token_per_block; + group_scores_buf = masked_scores_buf + num_experts * num_token_per_block; topk_indices_buf = reinterpret_cast(group_scores_buf + num_groups * num_token_per_block); } else { topk_indices_buf = reinterpret_cast(topk_scores_buf + topk * num_token_per_block); } // The address of buffers on the current warp - DataType *scores = scores_buf + warp_id * num_experts; - DataType *topk_scores = topk_scores_buf + warp_id * topk; - DataType *masked_scores = masked_scores_buf + warp_id * num_experts; - DataType *group_scores = group_scores_buf + warp_id * num_groups; + float *scores = scores_buf + warp_id * num_experts; + float *topk_scores = topk_scores_buf + warp_id * topk; + float *masked_scores = masked_scores_buf + warp_id * num_experts; + float *group_scores = group_scores_buf + warp_id * num_groups; int *topk_indices = topk_indices_buf + warp_id * topk; /*** @@ -75,7 +70,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( probs[pos_offset + i] = 0.0f; routing_map[pos_offset + i] = false; if (score_function == 1) { - intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); + intermediate_output[pos_offset + i] = -std::numeric_limits::infinity(); } } // Load the logits to shmem @@ -85,7 +80,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( // If group_topk > 0, init the masked_scores to -inf if 
(group_topk > 0) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - masked_scores[i] = -std::numeric_limits::infinity(); + masked_scores[i] = -std::numeric_limits::infinity(); } } __threadfence_block(); @@ -96,11 +91,11 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible preprocess the scores before the topk operation * - Pre-softmax * - Sigmoid + * - Sqrtsoftplus * - Expert bias * This is in-place scores update */ - // score_function == 1 means softmax - if (use_pre_softmax && score_function == 1) { + if (use_pre_softmax && score_function == 1) { // score_function == 1 means softmax // Apply softmax to the logits before the topk apply_softmax_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -108,10 +103,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } - } - - // score_function == 0 means sigmoid - if (score_function == 0) { + } else if (score_function == 0) { // score_function == 0 means sigmoid // Apply sigmoid to the logits apply_sigmoid_on_float(scores, num_experts, lane_id); __syncwarp(); @@ -119,18 +111,25 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { intermediate_output[pos_offset + i] = scores[i]; } + } else if (score_function == 2) { // score_function == 2 means sqrtsoftplus + // First save the original logits for backward (needed for sigmoid computation) + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + intermediate_output[pos_offset + i] = scores[i]; // Save original logits + } + __syncwarp(); + // Apply sqrtsoftplus to the logits + apply_sqrtsoftplus_on_float(scores, num_experts, lane_id); } - __syncwarp(); //Confirm the scores is written to the softmax/sigmoid output + __syncwarp(); //Confirm the scores is written to the output - // Expert bias is only used at the sigmoid case - if (expert_bias && score_function == 0) { + // Expert bias is only used at the sigmoid/sqrtsoftplus case + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - scores[i] = static_cast(static_cast(scores[i]) + - static_cast(expert_bias[i])); + scores[i] += static_cast(expert_bias[i]); } + __syncwarp(); } - __syncwarp(); /*** * Section: Topk @@ -140,7 +139,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( * - topk with expert bias */ // Topk on the scores - // The bias is not empty only happens at the sigmod case + // The bias being not empty happens at the sigmoid/sqrtsoftplus case if (group_topk > 0) { int group_size = num_experts / num_groups; // Top2 @@ -155,7 +154,7 @@ __global__ void fused_topk_with_score_function_forward_kernel( __syncwarp(); // Compute the group score if (lane_id == 0) { - DataType tmp = 0.0f; + float tmp = 0.0f; for (int j = 0; j < topk / group_topk; j++) { tmp = tmp + topk_scores[j]; } @@ -194,17 +193,16 @@ __global__ void fused_topk_with_score_function_forward_kernel( * Possible postprocess the scores after the topk operation * - Revert Expert bias * - Softmax - * - Sigmoid post-processing when topk > 1 + * - Sigmoid/Sqrtsoftplus post-processing when topk > 1 * - Write the result with scaling_factor */ // Revert Expert bias from the topk scores - if (expert_bias && score_function == 0) { + if (expert_bias && (score_function == 0 || score_function == 2)) { for (int i = lane_id; i < topk; i += 
kThreadsPerWarp) { - topk_scores[i] = - static_cast(topk_scores[i]) - static_cast(expert_bias[topk_indices[i]]); + topk_scores[i] = topk_scores[i] - expert_bias[topk_indices[i]]; } + __syncwarp(); } - __syncwarp(); // score_function == 1 means softmax if (!use_pre_softmax && score_function == 1) { @@ -215,14 +213,15 @@ __global__ void fused_topk_with_score_function_forward_kernel( for (int i = lane_id; i < topk; i += kThreadsPerWarp) { intermediate_output[pos_offset + topk_indices[i]] = topk_scores[i]; } + __syncwarp(); } - // score_function == 0 means sigmoid - if (score_function == 0) { + // Sigmoid/Sqrtsoftplus post-processing when topk > 1 + if (score_function == 0 || score_function == 2) { if (topk > 1) { - double sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); + float sum_scores = warp_reduce_on_shmem(topk_scores, topk, ReduceFuncType::SUM, lane_id); for (int i = lane_id; i < topk; i += kThreadsPerWarp) { - topk_scores[i] = static_cast(topk_scores[i]) / (sum_scores + epsilon); + topk_scores[i] = topk_scores[i] / (sum_scores + epsilon); } } __syncwarp(); @@ -231,29 +230,29 @@ __global__ void fused_topk_with_score_function_forward_kernel( // Write the probs/routing_map to the output tensor for (int i = lane_id; i < topk; i += kThreadsPerWarp) { routing_map[pos_offset + topk_indices[i]] = true; - probs[pos_offset + topk_indices[i]] = scaling_factor * static_cast(topk_scores[i]); + probs[pos_offset + topk_indices[i]] = scaling_factor * topk_scores[i]; } __threadfence_block(); __syncwarp(); } } -template +template void fused_topk_with_score_function_forward_kernel_launcher( const DataType *logits, int num_tokens, int num_experts, int topk, bool use_pre_softmax, int num_groups, int group_topk, float scaling_factor, int score_function, - const BiasType *expert_bias, DataType *probs, bool *routing_map, DataType *intermediate_output, + const float *expert_bias, DataType *probs, bool *routing_map, float *intermediate_output, cudaStream_t stream) { size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // scores - + topk * num_token_per_block * sizeof(DataType) // topk_scores - + topk * num_token_per_block * sizeof(int); // topk_indices + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(float) // scores + + topk * num_token_per_block * sizeof(float) // topk_scores + + topk * num_token_per_block * sizeof(int); // topk_indices if (group_topk > 0) { - shared_memory_size += num_groups * num_token_per_block * sizeof(DataType); // group_scores - shared_memory_size += num_experts * num_token_per_block * sizeof(DataType); // maksed_scores + shared_memory_size += num_groups * num_token_per_block * sizeof(float); // group_scores + shared_memory_size += num_experts * num_token_per_block * sizeof(float); // maksed_scores } - fused_topk_with_score_function_forward_kernel + fused_topk_with_score_function_forward_kernel <<>>( logits, num_tokens, num_experts, topk, use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, expert_bias, probs, routing_map, intermediate_output); @@ -268,21 +267,19 @@ void fused_topk_with_score_function_forward(const Tensor logits, int num_tokens, Tensor intermediate_output, cudaStream_t stream) { TE_ROUTER_PROBS_TYPE_SWITCH_ALL( logits.data.dtype, DataType, - TE_ROUTER_PROBS_TYPE_SWITCH_ALL( - expert_bias.data.dtype, BiasType, - 
fused_topk_with_score_function_forward_kernel_launcher( - reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, - use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, - reinterpret_cast(expert_bias.data.dptr), - reinterpret_cast(probs.data.dptr), - reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), stream););); + fused_topk_with_score_function_forward_kernel_launcher( + reinterpret_cast(logits.data.dptr), num_tokens, num_experts, topk, + use_pre_softmax, num_groups, group_topk, scaling_factor, score_function, + reinterpret_cast(expert_bias.data.dptr), + reinterpret_cast(probs.data.dptr), + reinterpret_cast(routing_map.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), stream);); } template __global__ void fused_topk_with_score_function_backward_kernel( // Inputs tensor - const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + const bool *routing_map, const float *intermediate_output, const DataType *grad_probs, // Other parameters int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, @@ -290,7 +287,6 @@ __global__ void fused_topk_with_score_function_backward_kernel( DataType *grad_logits) { /*** * Section: Global Variables/Addresses init - * - Assume the sizeof(DataType) >= sizeof(int), * - Each warp is responsible for one token, and has own shared memory buffer. * Then __syncwarp() is used instead of __syncthreads() */ @@ -299,18 +295,16 @@ __global__ void fused_topk_with_score_function_backward_kernel( int warp_id = threadIdx.x / kThreadsPerWarp; int lane_id = threadIdx.x % kThreadsPerWarp; extern __shared__ float shmem[]; - DataType *grad_probs_buf = reinterpret_cast(shmem); + float *grad_probs_buf = shmem; // To store the output of softmax/sigmoid from the fwd - DataType *act_from_fwd_buf = - reinterpret_cast(grad_probs_buf + num_experts * num_token_per_block); - DataType *comp_buf = - reinterpret_cast(act_from_fwd_buf + num_experts * num_token_per_block); + float *act_from_fwd_buf = grad_probs_buf + num_experts * num_token_per_block; + float *comp_buf = act_from_fwd_buf + num_experts * num_token_per_block; // To store the routing_map from the fwd bool *routing_map_buf = reinterpret_cast(comp_buf + num_experts * num_token_per_block); // The address of buffers on the current warp - DataType *local_grad = grad_probs_buf + warp_id * num_experts; - DataType *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; - DataType *local_comp_buf = comp_buf + warp_id * num_experts; + float *local_grad = grad_probs_buf + warp_id * num_experts; + float *local_act_from_fwd = act_from_fwd_buf + warp_id * num_experts; + float *local_comp_buf = comp_buf + warp_id * num_experts; bool *local_routing_map = routing_map_buf + warp_id * num_experts; /*** @@ -346,48 +340,68 @@ __global__ void fused_topk_with_score_function_backward_kernel( /*** * Section: Backward of ops after the topk * - Backward of the used scaling_factor - * - Sigmoid Post-processing bwd when topk > 1 + * - Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 * - Softmax bwd if use_pre_softmax is false */ // Backward of the used scaling_factor // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { - local_grad[i] = static_cast(local_grad[i]) * scaling_factor; + local_grad[i] = local_grad[i] * scaling_factor; } } __syncwarp(); - // Sigmoid Post-processing bwd when topk > 1 - if (topk > 1 && score_function 
== 0) { - double sum_fwd_input = masked_warp_reduce_on_shmem( - /*data ptr = */ local_act_from_fwd, - /*mask ptr = */ local_routing_map, - /*data size = */ num_experts, - /*reduce func = */ ReduceFuncType::SUM, lane_id); - // Put the result of output * grad to the comp_buf + + // Sqrtsoftplus: First compute sqrtsoftplus output from original logits + // (needed for both post-processing bwd and activation bwd, compute once here) + // For sqrtsoftplus, intermediate_output stores original logits + if (score_function == 2) { + // Copy original logits to local_comp_buf and apply sqrtsoftplus in-place for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { - local_comp_buf[i] = (local_routing_map[i] ? static_cast(local_grad[i]) * - static_cast(local_act_from_fwd[i]) - : 0.0f); + local_comp_buf[i] = local_act_from_fwd[i]; } __syncwarp(); - double sum_Output_x_Grad = masked_warp_reduce_on_shmem( - /*data ptr = */ local_comp_buf, + apply_sqrtsoftplus_on_float(local_comp_buf, num_experts, lane_id); + __syncwarp(); + } + + // Sigmoid/Sqrtsoftplus Post-processing bwd when topk > 1 (normalization backward) + if (topk > 1 && (score_function == 0 || score_function == 2)) { + // Select the correct activation output buffer: + // - Sigmoid: local_act_from_fwd already contains sigmoid output + // - Sqrtsoftplus: local_comp_buf contains sqrtsoftplus output computed above + float *act_output = (score_function == 0) ? local_act_from_fwd : local_comp_buf; + + float sum_fwd_input = masked_warp_reduce_on_shmem( + /*data ptr = */ act_output, /*mask ptr = */ local_routing_map, /*data size = */ num_experts, /*reduce func = */ ReduceFuncType::SUM, lane_id); + // Compute sum of output * grad using registers + float local_sum_Output_x_Grad = 0.0f; + for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { + if (local_routing_map[i]) { + local_sum_Output_x_Grad += local_grad[i] * act_output[i]; + } + } + // Warp reduce the sum + for (int s = 16; s > 0; s /= 2) { + local_sum_Output_x_Grad += __shfl_xor_sync(0xffffffff, local_sum_Output_x_Grad, s); + } + float sum_Output_x_Grad = local_sum_Output_x_Grad; // In-place update for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { if (local_routing_map[i]) { local_grad[i] = - static_cast(local_grad[i]) / (sum_fwd_input + epsilon) - + local_grad[i] / (sum_fwd_input + epsilon) - sum_Output_x_Grad / ((sum_fwd_input + epsilon) * (sum_fwd_input + epsilon)); } else { local_grad[i] = 0.0f; } } + __syncwarp(); } - __syncwarp(); + // Softmax bwd if use_pre_softmax is false if (!use_pre_softmax && score_function == 1) { apply_softmax_bwd_on_float(local_grad, local_act_from_fwd, local_comp_buf, local_routing_map, @@ -410,6 +424,7 @@ __global__ void fused_topk_with_score_function_backward_kernel( * Section: Backward of ops before the topk * - Pre-softmax bwd * - Sigmoid bwd + * - Sqrtsoftplus bwd * - Write the grad_logits to the global mem */ // Pre-softmax bwd @@ -423,6 +438,14 @@ __global__ void fused_topk_with_score_function_backward_kernel( apply_sigmoid_bwd_on_float(local_grad, local_act_from_fwd, num_experts, lane_id); __syncwarp(); } + // Sqrtsoftplus bwd + // For sqrtsoftplus, local_comp_buf already contains sqrtsoftplus output computed earlier + // Now compute gradient: dy/dx = sigmoid(x) / (2 * y) + if (score_function == 2) { + apply_sqrtsoftplus_bwd_on_float(local_grad, local_comp_buf, local_act_from_fwd, num_experts, + lane_id); + __syncwarp(); + } // Write the grad_logits to the global mem for (int i = lane_id; i < num_experts; i += kThreadsPerWarp) { 
grad_logits[pos_offset + i] = local_grad[i]; @@ -433,17 +456,16 @@ __global__ void fused_topk_with_score_function_backward_kernel( template void fused_topk_with_score_function_backward_kernel_launcher( - const bool *routing_map, const DataType *intermediate_output, const DataType *grad_probs, + const bool *routing_map, const float *intermediate_output, const DataType *grad_probs, int num_tokens, int num_experts, int topk, bool use_pre_softmax, float scaling_factor, int score_function, DataType *grad_logits, cudaStream_t stream) { // Meta data for the kernel size_t num_token_per_block = kThreadsPerBlock / kThreadsPerWarp; size_t grid_size = (num_tokens + num_token_per_block - 1) / num_token_per_block; - size_t shared_memory_size = num_experts * num_token_per_block * sizeof(DataType) // grad_probs - + - num_experts * num_token_per_block * sizeof(DataType) // act_from_fwd - + num_experts * num_token_per_block * sizeof(DataType) // comp_buf - + num_experts * num_token_per_block * sizeof(bool); // routing_map + size_t shared_memory_size = num_experts * num_token_per_block * sizeof(float) // grad_probs + + num_experts * num_token_per_block * sizeof(float) // act_from_fwd + + num_experts * num_token_per_block * sizeof(float) // comp_buf + + num_experts * num_token_per_block * sizeof(bool); // routing_map fused_topk_with_score_function_backward_kernel <<>>( routing_map, intermediate_output, grad_probs, num_tokens, num_experts, topk, @@ -461,7 +483,7 @@ void fused_topk_with_score_function_backward(const Tensor &routing_map, grad_logits.data.dtype, DataType, fused_topk_with_score_function_backward_kernel_launcher( reinterpret_cast(routing_map.data.dptr), - reinterpret_cast(intermediate_output.data.dptr), + reinterpret_cast(intermediate_output.data.dptr), reinterpret_cast(grad_probs.data.dptr), num_tokens, num_experts, topk, use_pre_softmax, scaling_factor, score_function, reinterpret_cast(grad_logits.data.dptr), stream);); diff --git a/transformer_engine/common/fused_router/utils.h b/transformer_engine/common/fused_router/utils.h index 4ae0b467b5..e1484202e2 100644 --- a/transformer_engine/common/fused_router/utils.h +++ b/transformer_engine/common/fused_router/utils.h @@ -35,19 +35,19 @@ template __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncType type, int lane_id) { T (*reduce_func)(T, T); - double default_val = 0; + float default_val = 0.0f; if (type == ReduceFuncType::SUM) { reduce_func = sum; - default_val = 0; + default_val = 0.0f; } else if (type == ReduceFuncType::MAX) { reduce_func = max; - default_val = -std::numeric_limits::infinity(); + default_val = -std::numeric_limits::infinity(); } // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = lane_id < data_size ? static_cast(data_ptr[lane_id]) : default_val; + volatile float val = lane_id < data_size ? 
data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { val = reduce_func(val, data_ptr[i]); } @@ -62,31 +62,23 @@ __device__ inline T warp_reduce_on_shmem(T *data_ptr, int data_size, ReduceFuncT return T(val); } -template -__device__ inline void apply_sigmoid_on_float(DataType *scores, int data_size, int lane_id) { - for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(1.0f / (1.0f + exp(-static_cast(scores[i])))); - } -} - template __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int data_size, ReduceFuncType type, int lane_id) { T (*reduce_func)(T, T); - double default_val = 0; + float default_val = 0.0f; if (type == ReduceFuncType::SUM) { reduce_func = sum; - default_val = 0; + default_val = 0.0f; } else if (type == ReduceFuncType::MAX) { reduce_func = max; - default_val = -std::numeric_limits::infinity(); + default_val = -std::numeric_limits::infinity(); } // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... // Reduce the value in local thread - volatile double val = - lane_id < data_size && mask[lane_id] ? static_cast(data_ptr[lane_id]) : default_val; + volatile float val = lane_id < data_size && mask[lane_id] ? data_ptr[lane_id] : default_val; for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { if (mask[i]) { val = reduce_func(val, data_ptr[i]); @@ -103,28 +95,70 @@ __device__ inline T masked_warp_reduce_on_shmem(T *data_ptr, bool *mask, int dat return T(val); } -template -__device__ inline void apply_sigmoid_bwd_on_float(DataType *grad, DataType *fwd_output, - int data_size, int lane_id) { +__device__ inline void apply_sigmoid_on_float(float *scores, int data_size, int lane_id) { for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - grad[i] = static_cast(grad[i]) * static_cast(fwd_output[i]) * - (1 - static_cast(fwd_output[i])); + scores[i] = 1.0f / (1.0f + expf(-scores[i])); } } -template -__device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_output, - DataType *comp_buf, bool *mask, int data_size, +__device__ inline void apply_sigmoid_bwd_on_float(float *grad, float *fwd_output, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + grad[i] = grad[i] * fwd_output[i] * (1.0f - fwd_output[i]); + } +} + +// sqrtsoftplus: y = sqrt(softplus(x)) = sqrt(log(1 + exp(x))) +__device__ inline void apply_sqrtsoftplus_on_float(float *scores, int data_size, int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = scores[i]; + // softplus(x) = log(1 + exp(x)), numerically stable version + // Matches PyTorch's Softplus(beta=1.0, threshold=20.0) + float softplus_val; + if (x > 20.0f) { + softplus_val = x; // for large x, softplus(x) ≈ x + } else { + softplus_val = log1pf(expf(x)); + } + scores[i] = sqrtf(softplus_val); + } +} + +// sqrtsoftplus backward: +// y = sqrt(softplus(x)) +// Matches PyTorch's Softplus(beta=1.0, threshold=20.0) +// We need the original logits (x) to compute the gradient +__device__ inline void apply_sqrtsoftplus_bwd_on_float(float *grad, float *fwd_output, + float *logits_buf, int data_size, + int lane_id) { + for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { + float x = logits_buf[i]; // original logit + float y = fwd_output[i]; // sqrtsoftplus output + float dy_dx; + if (x > 20.0f) { + // When softplus(x) = x, y = sqrt(x), dy/dx = 1/(2*y) + dy_dx = 1.0f / (2.0f 
* y + epsilon); + } else { + // When softplus(x) = log(1+exp(x)), dy/dx = sigmoid(x) / (2*y) + // where sigmoid(x) = 1 / (1 + exp(-x)) + float sigmoid_x = 1.0f / (1.0f + expf(-x)); + dy_dx = sigmoid_x / (2.0f * y + epsilon); + } + grad[i] = grad[i] * dy_dx; + } +} + +__device__ inline void apply_softmax_bwd_on_float(float *grad, float *fwd_output, float *comp_buf, + bool *mask, int data_size, int lane_id) { // Put the result of output * grad to the comp_buf for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { if (mask[i]) - comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + comp_buf[i] = grad[i] * fwd_output[i]; else comp_buf[i] = 0.0f; } else { - comp_buf[i] = static_cast(grad[i]) * static_cast(fwd_output[i]); + comp_buf[i] = grad[i] * fwd_output[i]; } } __syncwarp(); @@ -136,40 +170,34 @@ __device__ inline void apply_softmax_bwd_on_float(DataType *grad, DataType *fwd_ for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { if (mask) { if (mask[i]) - grad[i] = - static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); else grad[i] = 0.0f; } else { - grad[i] = - static_cast(fwd_output[i]) * (static_cast(grad[i]) - sum_Output_x_Grad); + grad[i] = fwd_output[i] * (grad[i] - sum_Output_x_Grad); } } } -template -__device__ inline void apply_softmax_on_float(DataType *scores, int data_size, int lane_id) { +__device__ inline void apply_softmax_on_float(float *scores, int data_size, int lane_id) { // 1. compute the max of value - float max_val = - static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id)); + float max_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::MAX, lane_id); // 2. value -> exp_value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(exp(static_cast(scores[i]) - max_val)); + scores[i] = expf(scores[i] - max_val); } __syncwarp(); // 3. compute the sum of exp_value - float sum_val = - static_cast(warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id)); + float sum_val = warp_reduce_on_shmem(scores, data_size, ReduceFuncType::SUM, lane_id); // 4. update the softmax value for (int i = lane_id; i < data_size; i += kThreadsPerWarp) { - scores[i] = static_cast(scores[i]) / sum_val; + scores[i] = scores[i] / sum_val; } __syncwarp(); } -template -__device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, int *topk_indices, - T *topk_scores, int lane_id) { +__device__ inline void naive_topk_and_mask(float *scores, int data_size, int topk, + int *topk_indices, float *topk_scores, int lane_id) { // Check if the index is masked by the later iteration auto is_masked = [&topk_indices](int k, int index) { if (k == 0) return false; @@ -183,16 +211,16 @@ __device__ inline void naive_topk_and_mask(T *scores, int data_size, int topk, i // After looping topk times, the topk_indices will be the topk indices for (int k = 0; k < topk; k++) { // Find the max value and its index - volatile double val = (lane_id < data_size && !is_masked(k, lane_id)) - ? static_cast(scores[lane_id]) - : -std::numeric_limits::infinity(); + volatile float val = (lane_id < data_size && !is_masked(k, lane_id)) + ? scores[lane_id] + : -std::numeric_limits::infinity(); volatile int index = (lane_id < data_size) ? lane_id : 0; // Some value is hanlded in local thread // Thread 0 is responsible for the: 0-th, 32-th, 64-th, 96-th ... 
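The gradient used by apply_sqrtsoftplus_bwd_on_float above, dy/dx = sigmoid(x) / (2 * y) with y = sqrt(softplus(x)), can be cross-checked against autograd; a minimal sketch (tensor size and tolerance are arbitrary):

import torch

x = torch.randn(4096, dtype=torch.float64, requires_grad=True)
y = torch.nn.functional.softplus(x).sqrt()
y.sum().backward()
manual = torch.sigmoid(x) / (2.0 * torch.nn.functional.softplus(x).sqrt())
assert torch.allclose(x.grad, manual.detach(), atol=1e-10)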
// Reduce the value in local thread for (int i = lane_id + kThreadsPerWarp; i < data_size; i += kThreadsPerWarp) { - volatile double cur_val = (is_masked(k, i)) ? -std::numeric_limits::infinity() - : static_cast(scores[i]); + volatile float cur_val = + (is_masked(k, i)) ? -std::numeric_limits::infinity() : scores[i]; if (cur_val > val) { val = cur_val; index = i; diff --git a/transformer_engine/common/include/transformer_engine/fused_router.h b/transformer_engine/common/include/transformer_engine/fused_router.h index 1f026a703d..794880d324 100644 --- a/transformer_engine/common/include/transformer_engine/fused_router.h +++ b/transformer_engine/common/include/transformer_engine/fused_router.h @@ -23,8 +23,8 @@ extern "C" { * \param[in] num_groups Number of groups in grouped topk. * \param[in] group_topk Grouped topk value. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. - * \param[in] expert_bias Expert bias. (Only used at the sigmoid case) + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. + * \param[in] expert_bias Expert bias. (Used at the sigmoid/sqrtsoftplus cases) * \param[out] probs Output tensor for probabilities. * \param[out] routing_map Output tensor for routing map. * \param[out] intermediate_output Output tensor for intermediate output. (Softmax/sigmoid output) @@ -46,7 +46,7 @@ void nvte_fused_topk_with_score_function_forward( * \param[in] topk Topk value. * \param[in] use_pre_softmax Whether to use softmax before topk. * \param[in] scaling_factor Scaling factor. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. */ @@ -63,7 +63,7 @@ void nvte_fused_topk_with_score_function_backward(const NVTETensor routing_map, * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] scores Output tensor for scores. * \param[in] routing_map Routing map. * \param[in] intermediate_output Intermediate output from the forward pass. (Softmax/sigmoid output) @@ -82,7 +82,7 @@ void nvte_fused_score_for_moe_aux_loss_forward(const NVTETensor logits, int num_ * \param[in] num_tokens Number of tokens. * \param[in] num_experts Number of experts. * \param[in] topk Topk value. - * \param[in] score_function Score function, 0: sigmoid, 1: softmax. + * \param[in] score_function Score function, 0: sigmoid, 1: softmax, 2: sqrtsoftplus. * \param[out] grad_logits Gradient of logits. * \param[in] stream CUDA stream used for the operation. 
*/ diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0ea3d6b78..cb98cce060 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -27,23 +27,22 @@ namespace transformer_engine::pytorch { **************************************************************************************************/ std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, - c10::optional group_topk, c10::optional scaling_factor, std::string score_function, - c10::optional expert_bias); + at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, + std::optional group_topk, std::optional scaling_factor, std::string score_function, + std::optional expert_bias); -at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, - at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, bool use_pre_softmax, - c10::optional scaling_factor, - std::string score_function); +void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, bool use_pre_softmax, + std::optional scaling_factor, + std::string score_function); std::tuple fused_score_for_moe_aux_loss_fwd( at::Tensor logits, int topk, std::string score_function); -at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, std::string score_function); +void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, std::string score_function); std::tuple fused_moe_aux_loss_fwd(at::Tensor probs, at::Tensor tokens_per_expert, diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1e907d9bc0..f84c5de5e8 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -325,19 +325,20 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { &transformer_engine::pytorch::fused_topk_with_score_function_fwd, py::arg("logits"), py::arg("topk"), py::arg("use_pre_softmax"), py::arg("num_groups"), py::arg("group_topk"), py::arg("scaling_factor"), py::arg("score_function"), py::arg("expert_bias"), - "Fused topk softmax fwd"); + "Fused topk with score function fwd"); m.def("fused_topk_with_score_function_bwd", &transformer_engine::pytorch::fused_topk_with_score_function_bwd, py::arg("num_tokens"), py::arg("num_experts"), py::arg("routing_map"), py::arg("intermediate_output"), - py::arg("grad_probs"), py::arg("topk"), py::arg("use_pre_softmax"), - py::arg("scaling_factor"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("grad_probs"), py::arg("grad_logits"), py::arg("topk"), py::arg("use_pre_softmax"), + py::arg("scaling_factor"), py::arg("score_function"), "Fused topk with score function bwd"); m.def("fused_score_for_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_fwd, py::arg("logits"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax fwd"); + py::arg("topk"), py::arg("score_function"), "Fused aux loss with score function fwd"); m.def("fused_score_for_moe_aux_loss_bwd", &transformer_engine::pytorch::fused_score_for_moe_aux_loss_bwd, 
py::arg("num_tokens"), py::arg("num_experts"), py::arg("intermediate_output"), py::arg("grad_scores"), - py::arg("topk"), py::arg("score_function"), "Fused topk softmax bwd"); + py::arg("grad_logits"), py::arg("topk"), py::arg("score_function"), + "Fused aux loss with score function bwd"); m.def("fused_moe_aux_loss_fwd", &transformer_engine::pytorch::fused_moe_aux_loss_fwd, py::arg("probs"), py::arg("tokens_per_expert"), py::arg("total_num_tokens"), py::arg("num_experts"), py::arg("num_rows"), py::arg("num_cols"), py::arg("topk"), diff --git a/transformer_engine/pytorch/csrc/extensions/router.cpp b/transformer_engine/pytorch/csrc/extensions/router.cpp index 2ae0d648a1..94625c0f12 100644 --- a/transformer_engine/pytorch/csrc/extensions/router.cpp +++ b/transformer_engine/pytorch/csrc/extensions/router.cpp @@ -9,12 +9,13 @@ namespace transformer_engine::pytorch { -static std::map score_function_map = {{"sigmoid", 0}, {"softmax", 1}}; +static std::map score_function_map = { + {"sigmoid", 0}, {"softmax", 1}, {"sqrtsoftplus", 2}}; std::tuple fused_topk_with_score_function_fwd( - at::Tensor logits, int topk, bool use_pre_softmax, c10::optional num_groups, - c10::optional group_topk, c10::optional scaling_factor, std::string score_function, - c10::optional expert_bias) { + at::Tensor logits, int topk, bool use_pre_softmax, std::optional num_groups, + std::optional group_topk, std::optional scaling_factor, std::string score_function, + std::optional expert_bias) { int num_tokens = logits.size(0); int num_experts = logits.size(1); // Check if the input is valid @@ -22,13 +23,16 @@ std::tuple fused_topk_with_score_function_fw "num_tokens and num_experts must be greater than 0"); // Expert bias only happens at the sigmoid case if (expert_bias.has_value()) { - TORCH_CHECK(score_function == "sigmoid", - "score_function must be sigmoid when expert_bias is not None"); + TORCH_CHECK(score_function == "sigmoid" || score_function == "sqrtsoftplus", + "score_function must be sigmoid or sqrtsoftplus when expert_bias is not None"); + TORCH_CHECK(expert_bias.value().scalar_type() == at::kFloat, + "expert_bias must be a float32 tensor"); } // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); - if (score_function == "sigmoid") { + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); + if (score_function == "sigmoid" || score_function == "sqrtsoftplus") { use_pre_softmax = false; // Pre-softmax only happens at the softmax case } @@ -44,7 +48,7 @@ std::tuple fused_topk_with_score_function_fw at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); // Intermediate output is used to store the output of the softmax/sigmoid function at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); auto logits_cu = makeTransformerEngineTensor(logits); auto probs_cu = makeTransformerEngineTensor(probs); @@ -64,18 +68,14 @@ std::tuple fused_topk_with_score_function_fw return std::make_tuple(probs, routing_map, intermediate_output); } -at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, - at::Tensor routing_map, - at::Tensor intermediate_output, at::Tensor grad_probs, - int topk, 
bool use_pre_softmax, - c10::optional scaling_factor, - std::string score_function) { +void fused_topk_with_score_function_bwd(int num_tokens, int num_experts, at::Tensor routing_map, + at::Tensor intermediate_output, at::Tensor grad_probs, + at::Tensor grad_logits, int topk, bool use_pre_softmax, + std::optional scaling_factor, + std::string score_function) { // Get the value of the parameters auto scaling_factor_value = scaling_factor.has_value() ? scaling_factor.value() : 1.0f; auto score_function_value = score_function_map[score_function]; - // Init the output tensor - at::Tensor grad_logits = at::empty( - {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); auto routing_map_cu = makeTransformerEngineTensor(routing_map); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); @@ -86,8 +86,6 @@ at::Tensor fused_topk_with_score_function_bwd(int num_tokens, int num_experts, routing_map_cu.data(), intermediate_output_cu.data(), grad_probs_cu.data(), num_tokens, num_experts, topk, use_pre_softmax, scaling_factor_value, score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream()); - - return grad_logits; } std::tuple fused_score_for_moe_aux_loss_fwd( @@ -99,17 +97,17 @@ std::tuple fused_score_for_moe_aux_loss_fwd( "num_tokens and num_experts must be greater than 0"); TORCH_CHECK(topk > 0, "topk must be greater than 0"); // Check if the score function is valid - TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid", - "score_function must be softmax or sigmoid for router fusion"); + TORCH_CHECK(score_function == "softmax" || score_function == "sigmoid" || + score_function == "sqrtsoftplus", + "score_function must be softmax, sigmoid or sqrtsoftplus for router fusion"); int score_function_value = score_function_map[score_function]; // Construct the output tensor - at::Tensor scores = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::Tensor scores = at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); at::Tensor routing_map = at::empty({num_tokens, num_experts}, at::dtype(at::kBool).device(at::kCUDA)); at::Tensor intermediate_output = - at::empty({num_tokens, num_experts}, at::dtype(logits.scalar_type()).device(at::kCUDA)); + at::empty({num_tokens, num_experts}, at::dtype(at::kFloat).device(at::kCUDA)); auto logits_cu = makeTransformerEngineTensor(logits); auto scores_cu = makeTransformerEngineTensor(scores); @@ -123,14 +121,12 @@ std::tuple fused_score_for_moe_aux_loss_fwd( return std::make_tuple(scores, routing_map, intermediate_output); } -at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, - at::Tensor intermediate_output, at::Tensor grad_scores, - int topk, std::string score_function) { +void fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, + at::Tensor intermediate_output, at::Tensor grad_scores, + at::Tensor grad_logits, int topk, + std::string score_function) { // Get the value of the parameters int score_function_value = score_function_map[score_function]; - // Init the output tensor - at::Tensor grad_logits = at::empty( - {num_tokens, num_experts}, at::dtype(intermediate_output.scalar_type()).device(at::kCUDA)); auto intermediate_output_cu = makeTransformerEngineTensor(intermediate_output); auto grad_scores_cu = makeTransformerEngineTensor(grad_scores); @@ -139,8 +135,6 @@ at::Tensor fused_score_for_moe_aux_loss_bwd(int num_tokens, int num_experts, 
   nvte_fused_score_for_moe_aux_loss_backward(
       intermediate_output_cu.data(), grad_scores_cu.data(), num_tokens, num_experts, topk,
       score_function_value, grad_logits_cu.data(), at::cuda::getCurrentCUDAStream());
-
-  return grad_logits;
 }

 std::tuple<at::Tensor, at::Tensor> fused_moe_aux_loss_fwd(at::Tensor probs,
diff --git a/transformer_engine/pytorch/router.py b/transformer_engine/pytorch/router.py
index 52d1d9d6ca..ee2c420473 100644
--- a/transformer_engine/pytorch/router.py
+++ b/transformer_engine/pytorch/router.py
@@ -3,7 +3,18 @@
 # See LICENSE for license information.
 """
 Fused functions used in the MoE router
+
+Precision Notes:
+- FP64 is currently not supported.
+- Inputs are cast to FP32 when loading from global memory.
+- All math/calculations/accumulations are done in FP32 inside the kernels.
+- "scores" is always in FP32 (matching the MCore implementation).
+- "intermediate_output" is always in FP32 for better backward precision.
+- Casts to low precision happen only when necessary, and only when writing to global
+  memory. For example, the gradient is required to have the same dtype as the input.
 """
+from typing import Optional
+
 import torch

 import transformer_engine_torch as tex
@@ -11,7 +22,7 @@
 class FusedTopkScoreFunction(torch.autograd.Function):
     """
     Fused Topk with Score Function router.
-    Currently, only support softmax and sigmoid.
+    Currently supports "softmax", "sigmoid" and "sqrtsoftplus".
     """

     @staticmethod
@@ -20,11 +31,11 @@ def forward(
         logits: torch.Tensor,
         topk: int,
         use_pre_softmax: bool,
-        num_groups: int,
-        group_topk: int,
-        scaling_factor: float,
+        num_groups: Optional[int],
+        group_topk: Optional[int],
+        scaling_factor: Optional[float],
         score_function: str,
-        expert_bias: torch.Tensor,
+        expert_bias: Optional[torch.Tensor],
     ):
         # pylint: disable=missing-function-docstring
         # Save the shape of the logits
@@ -52,6 +63,7 @@ def forward(
         ctx.topk = topk
         ctx.scaling_factor = scaling_factor
         ctx.score_function = score_function
+        ctx.logits_dtype = logits.dtype
         return probs, routing_map

     @staticmethod
@@ -62,12 +74,16 @@ def backward(ctx, grad_probs, _):
         tensor_shape = grad_probs.shape
         # Adjust the shape of the grad_probs to 2D shape
         grad_probs = grad_probs.contiguous().view(-1, tensor_shape[-1])
-        grad_logits = tex.fused_topk_with_score_function_bwd(
+        grad_logits = torch.empty(
+            (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_probs.device
+        )
+        tex.fused_topk_with_score_function_bwd(
             ctx.num_tokens,
             ctx.num_experts,
             routing_map,
             intermediate_output,
             grad_probs,
+            grad_logits,
             ctx.topk,
             ctx.use_pre_softmax,
             ctx.scaling_factor,
@@ -82,37 +98,37 @@ def fused_topk_with_score_function(
     logits: torch.Tensor,
     topk: int,
     use_pre_softmax: bool,
-    num_groups: int,
-    group_topk: int,
-    scaling_factor: float,
+    num_groups: Optional[int],
+    group_topk: Optional[int],
+    scaling_factor: Optional[float],
     score_function: str,
-    expert_bias: torch.Tensor,
+    expert_bias: Optional[torch.Tensor],
 ):
     """
     Fused topk with score function router.
     Parameters
     ----------
-    logits : torch.Tensor
+    logits : torch.Tensor in fp32/bf16/fp16
     topk : int
     use_pre_softmax : bool
-        if enabled, the computation order: softmax -> topk
-    num_groups : int
+        if enabled, the computation order is: softmax -> topk.
+    num_groups : int, optional
         used in the group topk
-    group_topk : int
+    group_topk : int, optional
         used in the group topk
-    scaling_factor : float
+    scaling_factor : float, optional
     score_function : str
-        currently only support softmax and sigmoid
-    expert_bias : torch.Tensor
-        could be used in the sigmoid
+        currently supports "softmax", "sigmoid" and "sqrtsoftplus".
+    expert_bias : torch.Tensor in fp32, optional
+        may be used with the sigmoid/sqrtsoftplus score functions.

     Returns
     -------
-    probs : torch.Tensor
-    routing_map : torch.Tensor
+    probs : torch.Tensor in the same dtype as the "logits".
+    routing_map : torch.Tensor in bool.
     """
     if logits.dtype == torch.float64:
-        raise ValueError("Current TE does not support float64 router type")
+        raise ValueError("Current TE does not support float64 router type.")
     return FusedTopkScoreFunction.apply(
         logits,
         topk,
@@ -154,6 +170,7 @@ def forward(
         ctx.score_function = score_function
         ctx.num_tokens = num_tokens
         ctx.num_experts = num_experts
+        ctx.logits_dtype = logits.dtype
         return routing_map, scores

     @staticmethod
@@ -164,11 +181,15 @@ def backward(ctx, _, grad_scores):
         tensor_shape = grad_scores.shape
         # Adjust the shape of the grad_scores to 2D shape
         grad_scores = grad_scores.contiguous().view(-1, tensor_shape[-1])
-        grad_logits = tex.fused_score_for_moe_aux_loss_bwd(
+        grad_logits = torch.empty(
+            (ctx.num_tokens, ctx.num_experts), dtype=ctx.logits_dtype, device=grad_scores.device
+        )
+        tex.fused_score_for_moe_aux_loss_bwd(
             num_tokens=ctx.num_tokens,
             num_experts=ctx.num_experts,
             intermediate_output=intermediate_output,
             grad_scores=grad_scores,
+            grad_logits=grad_logits,
             topk=ctx.topk,
             score_function=ctx.score_function,
         )
@@ -186,15 +207,15 @@ def fused_compute_score_for_moe_aux_loss(
     Fused compute scores for MoE aux loss, subset of the fused_topk_with_score_function.
     Parameters
     ----------
-    logits : torch.Tensor
+    logits : torch.Tensor in fp32/bf16/fp16
     topk : int
     score_function : str
-        currently only support softmax and sigmoid
+        currently supports "softmax", "sigmoid" and "sqrtsoftplus".

     Returns
     -------
-    routing_map : torch.Tensor
-    scores : torch.Tensor
+    routing_map : torch.Tensor in bool
+    scores : torch.Tensor in fp32
     """

     return FusedComputeScoresForMoEAuxLoss.apply(logits, topk, score_function)
@@ -253,23 +274,24 @@ def fused_moe_aux_loss(
     num_experts: int,
     topk: int,
     coeff: float,
-):
+) -> torch.Tensor:
     """
     Fused MoE aux loss.
     Parameters
     ----------
-    probs : torch.Tensor
-    tokens_per_expert : torch.Tensor
-        the number of tokens per expert
+    probs : torch.Tensor in fp32/bf16/fp16
+    tokens_per_expert : torch.Tensor in int32/int64/fp32/bf16
+        the number of tokens per expert.
     total_num_tokens : int
-        the total number of tokens, involved in the aux loss calculation
+        the total number of tokens used in the aux loss calculation.
     num_experts : int
     topk : int
     coeff : float
-        the coefficient of the aux loss
+        the coefficient of the aux loss.

     Returns
     -------
-    aux_loss : torch.scalar
+    aux_loss : torch.Tensor
+        A scalar tensor in the same dtype as the "probs".
     """
     return FusedAuxLoss.apply(probs, tokens_per_expert, total_num_tokens, num_experts, topk, coeff)
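
Usage sketch (editorial note, not part of the patch): the snippet below shows one way the new "sqrtsoftplus" score function could be exercised end to end through the fused router API modified above. It assumes a CUDA device, a working Transformer Engine installation, and that fused_topk_with_score_function is importable from transformer_engine.pytorch.router; the sizes and dtypes are illustrative only.

import torch

from transformer_engine.pytorch.router import fused_topk_with_score_function

# Illustrative sizes; any (num_tokens, num_experts) layout works.
num_tokens, num_experts, topk = 16, 32, 4
logits = torch.randn(
    num_tokens, num_experts, device="cuda", dtype=torch.bfloat16, requires_grad=True
)
# Per the new check in router.cpp, expert_bias must be an FP32 tensor when provided.
expert_bias = torch.zeros(num_experts, device="cuda", dtype=torch.float32)

probs, routing_map = fused_topk_with_score_function(
    logits=logits,
    topk=topk,
    use_pre_softmax=False,  # ignored for non-softmax score functions
    num_groups=None,
    group_topk=None,
    scaling_factor=None,
    score_function="sqrtsoftplus",
    expert_bias=expert_bias,
)

# probs comes back in logits.dtype; routing_map is a bool mask with topk experts per token.
assert probs.dtype == logits.dtype and routing_map.dtype == torch.bool
assert bool(routing_map.sum(dim=-1).eq(topk).all())

# The backward pass writes grad_logits in the logits dtype, per the precision notes above.
probs.sum().backward()
assert logits.grad is not None and logits.grad.dtype == logits.dtype

Passing expert_bias here exercises the code path that this patch newly allows for the sqrtsoftplus score function (previously restricted to sigmoid).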