83 changes: 57 additions & 26 deletions tests/pytorch/test_fused_router.py
@@ -47,7 +47,7 @@ def group_limited_topk(
 
 
-# Pytorch-based topk softmax/sigmoid
-def topk_softmax_sigmoid_pytorch(
+# Pytorch-based topk score functions (softmax/sigmoid/sqrtsoftplus)
+def topk_score_function_pytorch(
     logits: torch.Tensor,
     topk: int,
     use_pre_softmax: bool = False,
@@ -74,17 +74,20 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
 
     if score_function == "softmax":
         if use_pre_softmax:
-            scores = torch.softmax(logits, dim=-1, dtype=torch.float32).type_as(logits)
+            scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
             probs, top_indices = compute_topk(scores, topk, num_groups, group_topk)
         else:
             scores, top_indices = compute_topk(logits, topk, num_groups, group_topk)
-            probs = torch.softmax(scores, dim=-1, dtype=torch.float32).type_as(logits)
-    elif score_function == "sigmoid":
-        scores = torch.sigmoid(logits.float()).type_as(logits)
+            probs = torch.softmax(scores, dim=-1, dtype=torch.float32)
+    elif score_function in ("sigmoid", "sqrtsoftplus"):
+        if score_function == "sigmoid":
+            scores = torch.sigmoid(logits.float())
+        else:
+            scores = torch.nn.functional.softplus(logits.float()).sqrt()
         if expert_bias is not None:
             scores_for_routing = scores + expert_bias
             _, top_indices = compute_topk(scores_for_routing, topk, num_groups, group_topk)
-            scores = torch.gather(scores, dim=1, index=top_indices).type_as(logits)
+            scores = torch.gather(scores, dim=1, index=top_indices)
         else:
             scores, top_indices = compute_topk(scores, topk, num_groups, group_topk)
         probs = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
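
The new "sqrtsoftplus" branch above scores experts with sqrt(softplus(logits)) rather than sigmoid(logits); note that expert_bias only steers which experts are selected and is excluded from the returned weights. A minimal standalone sketch of that routing math, ignoring the grouped top-k path that compute_topk also supports (the helper name below is illustrative, not from the PR):

import torch

def sqrtsoftplus_routing(logits, topk, expert_bias=None):
    # sqrt(softplus(x)) is smooth and non-negative like sigmoid, but unbounded above.
    scores = torch.nn.functional.softplus(logits.float()).sqrt()
    if expert_bias is not None:
        # The bias influences selection only; the weights come from the raw scores.
        _, idx = torch.topk(scores + expert_bias, k=topk, dim=-1)
        picked = torch.gather(scores, dim=-1, index=idx)
    else:
        picked, idx = torch.topk(scores, k=topk, dim=-1)
    # Renormalize the selected scores so each token's weights sum to 1.
    probs = picked / (picked.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else picked
    return probs.type_as(logits), idx

probs, idx = sqrtsoftplus_routing(torch.randn(4, 8), topk=2)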
@@ -94,6 +97,8 @@ def compute_topk(scores, topk, num_groups=None, group_topk=None):
     if scaling_factor:
         probs = probs * scaling_factor
 
+    probs = probs.type_as(logits)
+
     topk_masked_gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
     topk_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
 
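After the cast added above, the top-k weights are expanded into dense [num_tokens, num_experts] tensors via scatter. A toy illustration of that step (shapes made up for readability):

import torch

logits = torch.zeros(2, 4)                    # 2 tokens, 4 experts
top_indices = torch.tensor([[0, 3], [1, 2]])  # topk = 2 selections per token
probs = torch.tensor([[0.7, 0.3], [0.5, 0.5]])

gates = torch.zeros_like(logits).scatter(1, top_indices, probs)
routing_map = torch.zeros_like(logits).int().scatter(1, top_indices, 1).bool()
print(gates)        # per-expert weights, zeros for unselected experts
print(routing_map)  # boolean mask of selected experts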
@@ -107,7 +112,10 @@ def compute_scores_for_aux_loss_pytorch(
     if score_function == "softmax":
         scores = torch.softmax(logits, dim=-1, dtype=torch.float32)
     elif score_function == "sigmoid":
-        scores = torch.sigmoid(logits)
+        scores = torch.sigmoid(logits.float())
         scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
+    elif score_function == "sqrtsoftplus":
+        scores = torch.nn.functional.softplus(logits.float()).sqrt()
+        scores = scores / (scores.sum(dim=-1, keepdim=True) + 1e-20) if topk > 1 else scores
     else:
         raise ValueError(f"Invalid score_function: {score_function}")
@@ -146,8 +154,9 @@ def run_comparison(
     enable_bias,
 ):
     # Set some parameters
-    if score_function == "sigmoid":
-        # Construct the special logits to avoid inf in the sigmoid function
+    if score_function in ("sigmoid", "sqrtsoftplus"):
+        # Construct logits with a narrow range to avoid very small activation values,
+        # which would cause precision loss when adding/subtracting expert bias in float32.
         offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
         logits = (
             torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
@@ -165,8 +174,8 @@ def run_comparison(
         )
         logits = logits.view(num_tokens, num_experts)
     logits.requires_grad = True
-    if enable_bias and score_function == "sigmoid":
-        expert_bias = torch.arange(num_experts, device="cuda") * 0.1
+    if enable_bias and score_function in ("sigmoid", "sqrtsoftplus"):
+        expert_bias = torch.arange(num_experts, device="cuda", dtype=dtype) * 0.1
         expert_bias = torch.flip(expert_bias, dims=[0])
         expert_bias.requires_grad = True
     else:
@@ -183,7 +192,7 @@ def run_comparison(
 
     # Run the original implementation
     # We do not support the capacity factor case
-    probs, routing_map = topk_softmax_sigmoid_pytorch(
+    probs, routing_map = topk_score_function_pytorch(
         logits=logits,
         topk=topk,
         use_pre_softmax=use_pre_softmax,
@@ -252,6 +261,37 @@ def test_topk_sigmoid(
     )
 
 
+@pytest.mark.parametrize("dtype", [torch.float32])
+@pytest.mark.parametrize("num_tokens", [2048, 7168, 8992])
+@pytest.mark.parametrize("num_experts", [128, 32])
+@pytest.mark.parametrize("topk", [4, 8])
+@pytest.mark.parametrize("group_topk", [None, 4])
+@pytest.mark.parametrize("scaling_factor", [None, 1.2])
+@pytest.mark.parametrize("enable_bias", [True, False])
+def test_topk_sqrtsoftplus(
+    dtype,
+    num_tokens,
+    num_experts,
+    topk,
+    group_topk,
+    scaling_factor,
+    enable_bias,
+):
+    num_groups = 8 if group_topk else None
+    run_comparison(
+        dtype=dtype,
+        num_tokens=num_tokens,
+        num_experts=num_experts,
+        topk=topk,
+        use_pre_softmax=False,
+        num_groups=num_groups,
+        group_topk=group_topk,
+        scaling_factor=scaling_factor,
+        score_function="sqrtsoftplus",
+        enable_bias=enable_bias,
+    )
+
+
 @pytest.mark.parametrize("dtype", [torch.float32])
 @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [128, 32])
@@ -287,10 +327,10 @@ def test_topk_softmax(
 @pytest.mark.parametrize("num_tokens", [2048, 7168, 14234])
 @pytest.mark.parametrize("num_experts", [256, 128, 32])
 @pytest.mark.parametrize("topk", [4, 8])
-@pytest.mark.parametrize("score_function", ["softmax", "sigmoid"])
+@pytest.mark.parametrize("score_function", ["softmax", "sigmoid", "sqrtsoftplus"])
 def test_fused_scores_for_aux_loss(dtype, num_tokens, num_experts, topk, score_function):
-    if score_function == "sigmoid":
-        # Construct the special logits to avoid inf in the sigmoid function
+    if score_function in ("sigmoid", "sqrtsoftplus"):
+        # Construct logits with a narrow range to avoid very small activation values
         offset = torch.arange(-num_tokens // 2, num_tokens // 2, dtype=dtype, device="cuda") * 1e-4
         logits = (
             torch.arange(-num_experts // 2, num_experts // 2, device="cuda", dtype=dtype) * 1e-2
@@ -396,15 +436,6 @@ def profile_topk_softmax(
     test_topk_softmax(
         torch.float32, num_tokens, num_experts, topk, use_pre_softmax, group_topk, scaling_factor
     )
-
-
-if __name__ == "__main__":
-    test_topk_softmax(
-        dtype=torch.float32,
-        num_tokens=1024,
-        num_experts=128,
-        topk=4,
-        use_pre_softmax=False,
-        group_topk=None,
-        scaling_factor=None,
+    test_topk_sqrtsoftplus(
+        torch.float32, num_tokens, num_experts, topk, group_topk, scaling_factor, enable_bias
     )
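
The "narrow range" comments in run_comparison and test_fused_scores_for_aux_loss refer to a float32 rounding hazard: when activations are tiny relative to the expert bias, adding the bias for routing discards the activations' low-order bits, so values recovered afterwards differ slightly and can flip close top-k decisions. A toy demonstration with made-up magnitudes (not taken from the tests):

import torch

score = torch.tensor(1e-4, dtype=torch.float32)  # tiny activation
bias = torch.tensor(12.7, dtype=torch.float32)   # much larger expert bias
roundtrip = (score + bias) - bias
# float32 carries roughly 7 significant digits, so the recovered value
# can be off by around a percent of the original score:
print(score.item(), roundtrip.item(), (roundtrip - score).abs().item())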
10 changes: 5 additions & 5 deletions transformer_engine/common/fused_router/fused_moe_aux_loss.cu
@@ -18,7 +18,7 @@
 namespace transformer_engine {
 
-// Using Double to hanld all the calculations
-using CompType = double;
+// Using float to handle all the calculations
+using CompType = float;
 
 template <typename DataType, typename IndexType>
 __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
@@ -98,7 +98,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
      * Section: Compute the aux_loss
      */
     float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
-    aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
+    aux_loss[0] = static_cast<DataType>(intermediate_result * C_coeff);
     Const_buf[0] = C_coeff;
   }
 }
@@ -154,7 +154,7 @@ __global__ void fused_moe_aux_loss_forward_kernel(const DataType* probs,
      * Section: Compute the aux_loss
      */
     float C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens;
-    aux_loss[0] = static_cast<DataType>(static_cast<double>(intermediate_result) * C_coeff);
+    aux_loss[0] = static_cast<DataType>(intermediate_result * C_coeff);
     Const_buf[0] = C_coeff;
   }
 }
@@ -229,8 +229,8 @@ __global__ void fused_moe_aux_loss_backward_kernel(const float* Const_buf,
   // Loop: for all positions in each row
   for (int i = lane_id; i < num_cols; i += kThreadsPerWarp) {
     float C_coeff = Const_buf[0];
-    double tokens_per_expert_i = static_cast<double>(tokens_per_expert[i]);
-    double grad_aux_loss_value = static_cast<double>(grad_aux_loss[0]);
+    CompType tokens_per_expert_i = static_cast<CompType>(tokens_per_expert[i]);
+    CompType grad_aux_loss_value = static_cast<CompType>(grad_aux_loss[0]);
     // Loop: for all rows
     for (int j = global_warp_id; j < num_rows; j += global_warp_num) {
       grad_probs[j * num_cols + i] = C_coeff * tokens_per_expert_i * grad_aux_loss_value;
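
As a reading aid for the kernel hunks above: the forward now accumulates in float (CompType) and scales intermediate_result by C_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens, while the backward writes C_coeff * tokens_per_expert[i] * grad_aux_loss into every row of column i. That gradient implies the forward reduction is a dot product of per-expert probability sums with tokens_per_expert; intermediate_result itself is not shown in this diff, so the reference below is a hedged sketch under that assumption:

import torch

def aux_loss_forward_ref(probs, tokens_per_expert, total_num_tokens,
                         num_experts, topk, coeff):
    # Same constant as the kernel's C_coeff.
    c_coeff = (num_experts * coeff) / topk / total_num_tokens / total_num_tokens
    # Assumed reduction, inferred from the backward kernel's gradient.
    intermediate = (probs.float().sum(dim=0) * tokens_per_expert.float()).sum()
    return intermediate * c_coeff, c_coeff

def aux_loss_backward_ref(grad_aux_loss, c_coeff, tokens_per_expert, num_rows):
    # Every row of column i gets C_coeff * tokens_per_expert[i] * grad_aux_loss.
    per_col = c_coeff * tokens_per_expert.float() * grad_aux_loss
    return per_col.expand(num_rows, -1)

Doing both passes in float32 mirrors the PR's switch of CompType from double to float, which is what the updated tests compare against.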