192 changes: 192 additions & 0 deletions docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb
@@ -0,0 +1,192 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Mixtral MoE with Transformer Engine\n",
"\n",
"## Step 1: Wrap MoE Layers with TE Modules\n",
"\n",
"This notebook demonstrates wrapping Mixtral's MoE FFN layers with Transformer Engine's `GroupedLinear` for efficient expert processing.\n",
"\n",
"Reference: `src/transformers/models/mixtral/modular_mixtral.py`"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import torch.nn as nn\n",
"import torch.nn.functional as F\n",
"from typing import Optional, Tuple\n",
"import transformer_engine.pytorch as te\n",
"from transformer_engine.pytorch import GroupedLinear\n",
"\n",
"class TEMixtralSparseMoeBlock(nn.Module):\n",
" \"\"\"\n",
" Transformer Engine optimized MoE block using GroupedLinear for parallel expert processing.\n",
" \n",
" Key improvements:\n",
" 1. Use te.GroupedLinear to process all experts in a single batched GEMM\n",
" 2. Use te.moe_permute/unpermute for efficient token routing\n",
" \"\"\"\n",
" def __init__(self, config):\n",
" super().__init__()\n",
" self.hidden_dim = config.hidden_size\n",
" self.ffn_dim = config.intermediate_size\n",
" self.num_experts = config.num_local_experts\n",
" self.top_k = config.num_experts_per_tok\n",
" \n",
" # Keep HuggingFace router (not in critical path for performance)\n",
" self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)\n",
" \n",
" # Replace individual expert layers with GroupedLinear\n",
" # GroupedLinear processes all experts in parallel with a single GEMM\n",
" # For SwiGLU: w1 (gate) and w3 (up) are combined, then w2 (down)\n",
" \n",
" # w1 and w3 combined (gate_proj + up_proj)\n",
" self.experts_gate_up = GroupedLinear(\n",
" num_gemms=self.num_experts,\n",
" in_features=self.hidden_dim,\n",
" out_features=2 * self.ffn_dim, # 2x for gate and up proj combined\n",
" bias=False,\n",
" device='cuda'\n",
" )\n",
" \n",
" # w2 (down_proj)\n",
" self.experts_down = GroupedLinear(\n",
" num_gemms=self.num_experts,\n",
" in_features=self.ffn_dim,\n",
" out_features=self.hidden_dim,\n",
" bias=False,\n",
" device='cuda'\n",
" )\n",
" \n",
" def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n",
" \"\"\"\n",
" Args:\n",
" hidden_states: [batch_size, sequence_length, hidden_dim]\n",
" \n",
" Returns:\n",
" final_hidden_states: [batch_size, sequence_length, hidden_dim]\n",
" router_logits: [batch_size * sequence_length, num_experts]\n",
" \"\"\"\n",
" batch_size, sequence_length, hidden_dim = hidden_states.shape\n",
" hidden_states_flat = hidden_states.view(-1, hidden_dim) # [num_tokens, hidden_dim]\n",
" num_tokens = hidden_states_flat.shape[0]\n",
" \n",
" # Router: Get expert assignments for each token\n",
" router_logits = self.gate(hidden_states_flat)\n",
" routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n",
" routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n",
" routing_weights /= routing_weights.sum(dim=-1, keepdim=True)\n",
" routing_weights = routing_weights.to(hidden_states.dtype)\n",
" \n",
" # Permute tokens by expert assignment\n",
" # moe_permute groups tokens going to the same expert together\n",
" permuted_tokens, row_id_map = te.moe_permute(\n",
" hidden_states_flat,\n",
" selected_experts.to(torch.int32),\n",
" num_out_tokens=None, # Auto-calculate\n",
" max_token_num=num_tokens\n",
Comment on lines +91 to +95
Contributor
Setting `num_out_tokens` to `None` is fine for auto-calculation, but when using `top_k > 1`, the expected output token count should be `num_tokens * top_k` since each token is routed to multiple experts.

" )\n",
" \n",
" # Calculate m_splits: number of tokens assigned to each expert\n",
" m_splits = []\n",
" for expert_idx in range(self.num_experts):\n",
" expert_mask = (selected_experts == expert_idx).any(dim=-1)\n",
" m_splits.append(expert_mask.sum().item() * self.top_k)\n",
Comment on lines +98 to +102
Contributor

Logic error in the `m_splits` calculation. The current approach counts tokens incorrectly by multiplying by `top_k` after already considering all expert assignments.

The issue: `expert_mask` already captures ALL tokens that selected this expert (across all top-k positions), so multiplying by `self.top_k` double-counts.

For example, if token 0 selects experts [1, 3] and token 1 selects experts [1, 2], then for expert 1: `expert_mask` will be `[True, True]` (sum = 2). Multiplying by `top_k = 2` gives 4, but only 2 tokens actually go to expert 1.

Suggested change
" # Calculate m_splits: number of tokens assigned to each expert\n",
" m_splits = []\n",
" for expert_idx in range(self.num_experts):\n",
" expert_mask = (selected_experts == expert_idx).any(dim=-1)\n",
" m_splits.append(expert_mask.sum().item() * self.top_k)\n",
# Calculate m_splits: number of tokens assigned to each expert
m_splits = []
for expert_idx in range(self.num_experts):
expert_mask = (selected_experts == expert_idx).any(dim=-1)
m_splits.append(expert_mask.sum().item())

" \n",
" # Process all experts in parallel using GroupedLinear\n",
" # Gate and Up projection (combined)\n",
" intermediate = self.experts_gate_up(permuted_tokens, m_splits=m_splits)\n",
" \n",
" # Apply SwiGLU activation: silu(gate) * up\n",
" gate, up = intermediate.chunk(2, dim=-1)\n",
" intermediate_act = F.silu(gate) * up\n",
" \n",
" # Down projection\n",
" expert_outputs = self.experts_down(intermediate_act, m_splits=m_splits)\n",
" \n",
" # Unpermute tokens back to original order and apply routing weights\n",
" final_hidden_states = te.moe_unpermute(\n",
" expert_outputs,\n",
" row_id_map,\n",
" probs=routing_weights\n",
" )\n",
" \n",
" final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n",
" return final_hidden_states, router_logits"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Test the Implementation"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Create a mock config for testing\n",
"class MixtralConfig:\n",
" hidden_size = 4096\n",
" intermediate_size = 14336\n",
" num_local_experts = 8\n",
" num_experts_per_tok = 2\n",
"\n",
"config = MixtralConfig()\n",
"\n",
"# Initialize TE-optimized MoE block\n",
"te_moe_block = TEMixtralSparseMoeBlock(config).cuda()\n",
"\n",
"# Test with sample input\n",
"batch_size, seq_len = 2, 16\n",
"hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device='cuda', dtype=torch.bfloat16)\n",
"\n",
"# Forward pass\n",
"with torch.no_grad():\n",
" output, router_logits = te_moe_block(hidden_states)\n",
" \n",
"print(f\"Input shape: {hidden_states.shape}\")\n",
"print(f\"Output shape: {output.shape}\")\n",
"print(f\"Router logits shape: {router_logits.shape}\")\n",
"print(f\"Output dtype: {output.dtype}\")\n",
"print(\"✓ TE-optimized MoE block working correctly!\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Next: Weight Mapping and Integration\n",
"\n",
"To integrate with HuggingFace Mixtral models, you need to:\n",
"\n",
"1. Map weights from HF `MixtralSparseMoeBlock` to `TEMixtralSparseMoeBlock`\n",
"2. Use monkey-patching to replace HF layers during model loading\n",
"3. Implement weight loading from HF checkpoints"
]
},
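{
"cell_type": "markdown",
"metadata": {},
"source": [
"The next cell is a rough, untested sketch of steps 1 and 2, not a finished implementation. It assumes that `GroupedLinear` registers one weight per expert under the names `weight0`, `weight1`, ..., that each HF expert exposes `w1` (gate), `w2` (down) and `w3` (up) as in `MixtralBlockSparseTop2MLP`, and that the decoder layers store their MoE block as `block_sparse_moe`; the helper names are illustrative, so adapt them to your Transformer Engine and transformers versions. After loading a Mixtral model you would call `replace_mixtral_moe_blocks(model, model.config)` once to swap every MoE block in place."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def swap_moe_block(hf_block, config):\n",
"    \"\"\"Build a TEMixtralSparseMoeBlock and copy weights over from an HF MixtralSparseMoeBlock.\"\"\"\n",
"    te_block = TEMixtralSparseMoeBlock(config).cuda().to(hf_block.gate.weight.dtype)\n",
"    with torch.no_grad():\n",
"        # Router weights map one-to-one\n",
"        te_block.gate.weight.copy_(hf_block.gate.weight)\n",
"        for i, expert in enumerate(hf_block.experts):\n",
"            # Assumed GroupedLinear parameter naming: weight0, weight1, ...\n",
"            gate_up_weight = getattr(te_block.experts_gate_up, f'weight{i}')\n",
"            down_weight = getattr(te_block.experts_down, f'weight{i}')\n",
"            # w1 = gate_proj, w3 = up_proj, stacked to match the chunk(2, dim=-1) split above\n",
"            gate_up_weight.copy_(torch.cat([expert.w1.weight, expert.w3.weight], dim=0))\n",
"            down_weight.copy_(expert.w2.weight)\n",
"    return te_block\n",
"\n",
"def replace_mixtral_moe_blocks(model, config):\n",
"    \"\"\"Monkey-patch every decoder layer's block_sparse_moe with the TE version (step 2).\"\"\"\n",
"    for layer in model.model.layers:\n",
"        layer.block_sparse_moe = swap_moe_block(layer.block_sparse_moe, config)\n",
"    return model"
]
},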
{
"cell_type": "markdown",
"metadata": {},
"source": []
}
],
"metadata": {
"language_info": {
"name": "python"
}
},
"nbformat": 4,
"nbformat_minor": 2
}