diff --git a/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb b/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb new file mode 100644 index 0000000000..2d61be19cd --- /dev/null +++ b/docs/examples/te_mixtral/tutorial_accelerate_hf_mixtral_with_te.ipynb @@ -0,0 +1,192 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Mixtral MoE with Transformer Engine\n", + "\n", + "## Step 1: Wrap MoE Layers with TE Modules\n", + "\n", + "This notebook demonstrates wrapping Mixtral's MoE FFN layers with Transformer Engine's `GroupedLinear` for efficient expert processing.\n", + "\n", + "Reference: `src/transformers/models/mixtral/modular_mixtral.py`" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from typing import Optional, Tuple\n", + "import transformer_engine.pytorch as te\n", + "from transformer_engine.pytorch import GroupedLinear\n", + "\n", + "class TEMixtralSparseMoeBlock(nn.Module):\n", + " \"\"\"\n", + " Transformer Engine optimized MoE block using GroupedLinear for parallel expert processing.\n", + " \n", + " Key improvements:\n", + " 1. Use te.GroupedLinear to process all experts in a single batched GEMM\n", + " 2. Use te.moe_permute/unpermute for efficient token routing\n", + " \"\"\"\n", + " def __init__(self, config):\n", + " super().__init__()\n", + " self.hidden_dim = config.hidden_size\n", + " self.ffn_dim = config.intermediate_size\n", + " self.num_experts = config.num_local_experts\n", + " self.top_k = config.num_experts_per_tok\n", + " \n", + " # Keep HuggingFace router (not in critical path for performance)\n", + " self.gate = nn.Linear(self.hidden_dim, self.num_experts, bias=False)\n", + " \n", + " # Replace individual expert layers with GroupedLinear\n", + " # GroupedLinear processes all experts in parallel with a single GEMM\n", + " # For SwiGLU: w1 (gate) and w3 (up) are combined, then w2 (down)\n", + " \n", + " # w1 and w3 combined (gate_proj + up_proj)\n", + " self.experts_gate_up = GroupedLinear(\n", + " num_gemms=self.num_experts,\n", + " in_features=self.hidden_dim,\n", + " out_features=2 * self.ffn_dim, # 2x for gate and up proj combined\n", + " bias=False,\n", + " device='cuda'\n", + " )\n", + " \n", + " # w2 (down_proj)\n", + " self.experts_down = GroupedLinear(\n", + " num_gemms=self.num_experts,\n", + " in_features=self.ffn_dim,\n", + " out_features=self.hidden_dim,\n", + " bias=False,\n", + " device='cuda'\n", + " )\n", + " \n", + " def forward(self, hidden_states: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:\n", + " \"\"\"\n", + " Args:\n", + " hidden_states: [batch_size, sequence_length, hidden_dim]\n", + " \n", + " Returns:\n", + " final_hidden_states: [batch_size, sequence_length, hidden_dim]\n", + " router_logits: [batch_size * sequence_length, num_experts]\n", + " \"\"\"\n", + " batch_size, sequence_length, hidden_dim = hidden_states.shape\n", + " hidden_states_flat = hidden_states.view(-1, hidden_dim) # [num_tokens, hidden_dim]\n", + " num_tokens = hidden_states_flat.shape[0]\n", + " \n", + " # Router: Get expert assignments for each token\n", + " router_logits = self.gate(hidden_states_flat)\n", + " routing_weights = F.softmax(router_logits, dim=1, dtype=torch.float)\n", + " routing_weights, selected_experts = torch.topk(routing_weights, self.top_k, dim=-1)\n", + " routing_weights /= routing_weights.sum(dim=-1, 
keepdim=True)\n", + " routing_weights = routing_weights.to(hidden_states.dtype)\n", + " \n", + " # Permute tokens by expert assignment\n", + " # moe_permute groups tokens going to the same expert together\n", + " permuted_tokens, row_id_map = te.moe_permute(\n", + " hidden_states_flat,\n", + " selected_experts.to(torch.int32),\n", + " num_out_tokens=None, # Auto-calculate\n", + " max_token_num=num_tokens\n", + " )\n", + " \n", + " # Calculate m_splits: number of permuted rows assigned to each expert.\n", + " # Each token contributes one row per expert it is routed to, so the\n", + " # counts sum to num_tokens * top_k and match moe_permute's output layout.\n", + " m_splits = []\n", + " for expert_idx in range(self.num_experts):\n", + " m_splits.append((selected_experts == expert_idx).sum().item())\n", + " \n", + " # Process all experts in parallel using GroupedLinear\n", + " # Gate and Up projection (combined)\n", + " intermediate = self.experts_gate_up(permuted_tokens, m_splits=m_splits)\n", + " \n", + " # Apply SwiGLU activation: silu(gate) * up\n", + " gate, up = intermediate.chunk(2, dim=-1)\n", + " intermediate_act = F.silu(gate) * up\n", + " \n", + " # Down projection\n", + " expert_outputs = self.experts_down(intermediate_act, m_splits=m_splits)\n", + " \n", + " # Unpermute tokens back to original order and apply routing weights\n", + " final_hidden_states = te.moe_unpermute(\n", + " expert_outputs,\n", + " row_id_map,\n", + " probs=routing_weights\n", + " )\n", + " \n", + " final_hidden_states = final_hidden_states.reshape(batch_size, sequence_length, hidden_dim)\n", + " return final_hidden_states, router_logits" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Test the Implementation" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Create a mock config for testing\n", + "class MixtralConfig:\n", + " hidden_size = 4096\n", + " intermediate_size = 14336\n", + " num_local_experts = 8\n", + " num_experts_per_tok = 2\n", + "\n", + "config = MixtralConfig()\n", + "\n", + "# Initialize the TE-optimized MoE block and cast its parameters to bfloat16\n", + "# so they match the activation dtype used below\n", + "te_moe_block = TEMixtralSparseMoeBlock(config).cuda().to(torch.bfloat16)\n", + "\n", + "# Test with sample input\n", + "batch_size, seq_len = 2, 16\n", + "hidden_states = torch.randn(batch_size, seq_len, config.hidden_size, device='cuda', dtype=torch.bfloat16)\n", + "\n", + "# Forward pass\n", + "with torch.no_grad():\n", + " output, router_logits = te_moe_block(hidden_states)\n", + " \n", + "print(f\"Input shape: {hidden_states.shape}\")\n", + "print(f\"Output shape: {output.shape}\")\n", + "print(f\"Router logits shape: {router_logits.shape}\")\n", + "print(f\"Output dtype: {output.dtype}\")\n", + "print(\"✓ TE-optimized MoE block forward pass succeeded!\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Next: Weight Mapping and Integration\n", + "\n", + "To integrate with HuggingFace Mixtral models, you need to:\n", + "\n", + "1. Map weights from HF `MixtralSparseMoeBlock` to `TEMixtralSparseMoeBlock`\n", + "2. Use monkey-patching to replace HF layers during model loading\n", + "3. Implement weight loading from HF checkpoints\n", + "\n", + "A sketch of steps 1 and 2 is shown below." + ] + } + ], + "metadata": { + "language_info": { + "name": "python" + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}
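As a hedged sketch of steps 1 and 2 above (weight mapping and monkey-patching), the snippet below is illustrative rather than part of this notebook or of Transformer Engine's public API: it assumes `GroupedLinear` registers one parameter per expert named `weight0` … `weight{num_experts-1}`, each shaped `[out_features, in_features]`, and the helper names `copy_moe_weights_from_hf` and `replace_moe_blocks` are made up for this example. The HuggingFace attribute names (`model.model.layers[i].block_sparse_moe`, its `gate`, and `experts[i].w1/w2/w3`) follow `modeling_mixtral.py`, and `TEMixtralSparseMoeBlock` is the class defined in the notebook.

```python
import torch


@torch.no_grad()
def copy_moe_weights_from_hf(hf_moe, te_moe):
    """Copy weights from a HF MixtralSparseMoeBlock into a TEMixtralSparseMoeBlock.

    Assumes GroupedLinear exposes per-expert weights as weight0 .. weight{N-1},
    each shaped [out_features, in_features].
    """
    # The router maps 1:1.
    te_moe.gate.weight.copy_(hf_moe.gate.weight)

    for i, expert in enumerate(hf_moe.experts):
        # HF expert layout: w1 = gate_proj, w3 = up_proj, w2 = down_proj.
        # TEMixtralSparseMoeBlock chunks the fused output as [gate | up], so
        # stack w1 on top of w3 along the output dimension.
        fused_gate_up = torch.cat([expert.w1.weight, expert.w3.weight], dim=0)
        getattr(te_moe.experts_gate_up, f"weight{i}").copy_(fused_gate_up)
        getattr(te_moe.experts_down, f"weight{i}").copy_(expert.w2.weight)


def replace_moe_blocks(model, config):
    """Swap each MixtralSparseMoeBlock in a loaded HF Mixtral model for the TE block."""
    # TEMixtralSparseMoeBlock is defined in the notebook cell above.
    for layer in model.model.layers:
        hf_moe = layer.block_sparse_moe
        ref = next(hf_moe.parameters())
        te_moe = TEMixtralSparseMoeBlock(config).to(ref.device, ref.dtype)
        copy_moe_weights_from_hf(hf_moe, te_moe)
        layer.block_sparse_moe = te_moe  # monkey-patch the decoder layer in place
    return model
```

With a Mixtral checkpoint loaded via `AutoModelForCausalLM.from_pretrained(..., torch_dtype=torch.bfloat16)`, calling `replace_moe_blocks(model, model.config)` swaps the blocks in place; comparing logits against the unmodified model on a short prompt is a reasonable sanity check before benchmarking.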