From 3ba991b447561f04fdad46f12b40ce87cf2deb0b Mon Sep 17 00:00:00 2001
From: Chen Cui
Date: Fri, 30 Jan 2026 16:20:09 -0800
Subject: [PATCH 1/4] fix subchannel fp8 + sp

Signed-off-by: Chen Cui
---
 transformer_engine/pytorch/distributed.py | 6 +++++-
 transformer_engine/pytorch/utils.py       | 7 +++++++
 2 files changed, 12 insertions(+), 1 deletion(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index f269e21b8c..fd519249a3 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -1100,7 +1100,11 @@ def _start_all_gather_fp8_blockwise(
     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
-        out = torch.empty(out_shape, dtype=dtype, device=device)
+        # Dequantize if input is already quantized
+        if isinstance(inp, Float8BlockwiseQTensorStorage):
+            inp = inp.dequantize()
+        # Use dtype from actual input tensor (may differ from initial guess after dequantize)
+        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None

diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 47af9fabe1..3edc0d7c77 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -452,6 +452,13 @@ def assert_dim_for_all_gather(
 ) -> None:
     """Assert that tensor dimensions are supported for all-gather"""
     if with_all_gather:
+        # Float8BlockQuantizer has a fallback path in gather_along_first_dim
+        # that handles non-quantizable local tensors by all-gathering in
+        # high precision first, then quantizing the result.
+        from .tensor import Float8BlockQuantizer
+
+        if isinstance(quantizer, Float8BlockQuantizer):
+            return
         assert quantizer.is_quantizable(tensor), (
             "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
         )
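
Note on the dtype/device comment in the hunk above: after dequantize() the local shard may
carry a different dtype than the one the caller originally planned the gather buffer with, and
torch.distributed.all_gather_into_tensor requires the output and input dtypes to match. A minimal
illustration with plain torch tensors (the shard here is a stand-in for a dequantized quantized
tensor, not the actual Transformer Engine type):

    import torch

    planned_dtype = torch.float32                     # dtype assumed before dequantization
    shard = torch.randn(4, 8, dtype=torch.bfloat16)   # stand-in for a dequantized local shard
    # Sizing the gather buffer from the actual shard avoids a dtype mismatch:
    out = torch.empty((2 * shard.shape[0], shard.shape[1]), dtype=shard.dtype, device=shard.device)
    assert out.dtype == shard.dtype and out.dtype != planned_dtype
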
From 637ba0f8cf91805e5963c565fe14b4339693873f Mon Sep 17 00:00:00 2001
From: Tim Moon
Date: Mon, 2 Feb 2026 19:59:13 +0000
Subject: [PATCH 2/4] Support sequence-parallel all-gather with small inputs

Perform all-gather in high-precision if the input tensor is too small to
quantize.

Signed-off-by: Tim Moon
---
 transformer_engine/pytorch/distributed.py    | 21 ++++++++++++-------
 .../pytorch/module/layernorm_linear.py       |  2 --
 .../pytorch/module/layernorm_mlp.py          |  2 --
 transformer_engine/pytorch/module/linear.py  |  2 --
 transformer_engine/pytorch/utils.py          | 17 ---------------
 5 files changed, 13 insertions(+), 31 deletions(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index fd519249a3..2208ff720c 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -1100,10 +1100,9 @@ def _start_all_gather_fp8_blockwise(
     # Fall back to high-precision all-gather if FP8 is not supported
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
-        # Dequantize if input is already quantized
-        if isinstance(inp, Float8BlockwiseQTensorStorage):
-            inp = inp.dequantize()
-        # Use dtype from actual input tensor (may differ from initial guess after dequantize)
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None
@@ -1342,10 +1341,13 @@ def _all_gather_nvfp4(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
@@ -1509,10 +1511,13 @@ def _all_gather_mxfp8(
         and quantizer is not None
         and not quantizer.is_quantizable(inp)
     ):
+        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
+        if isinstance(inp, QuantizedTensorStorage):
+            inp = inp.dequantize()  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=dtype,
-            device=device,
+            dtype=inp.dtype,
+            device=inp.device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
diff --git a/transformer_engine/pytorch/module/layernorm_linear.py b/transformer_engine/pytorch/module/layernorm_linear.py
index 702916696b..f79bc91c0a 100644
--- a/transformer_engine/pytorch/module/layernorm_linear.py
+++ b/transformer_engine/pytorch/module/layernorm_linear.py
@@ -29,7 +29,6 @@
 from ..quantization import FP8GlobalStateManager
 from ..utils import (
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     cast_if_needed,
     clear_tensor_data,
     divide,
@@ -158,7 +157,6 @@ def forward(
         inputmat = inp
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)

         # Cast for native AMP
         nvtx_range_push(f"{nvtx_label}.norm_input_cast")
diff --git a/transformer_engine/pytorch/module/layernorm_mlp.py b/transformer_engine/pytorch/module/layernorm_mlp.py
index bec6744518..d9f046aa38 100644
--- a/transformer_engine/pytorch/module/layernorm_mlp.py
+++ b/transformer_engine/pytorch/module/layernorm_mlp.py
@@ -40,7 +40,6 @@
     init_method_constant,
     cast_if_needed,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     clear_tensor_data,
     requires_grad,
     needs_quantized_gemm,
@@ -331,7 +330,6 @@ def _forward(
         inputmat = inp.view((-1, in_features))
         if fp8:
             assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
-            assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)

         activation_func = _act_func(
             activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
diff --git a/transformer_engine/pytorch/module/linear.py b/transformer_engine/pytorch/module/linear.py
index 23ad8cacb0..d7283cc047 100644
--- a/transformer_engine/pytorch/module/linear.py
+++ b/transformer_engine/pytorch/module/linear.py
@@ -34,7 +34,6 @@
     requires_grad,
     needs_quantized_gemm,
     assert_dim_for_fp8_exec,
-    assert_dim_for_all_gather,
     nvtx_range_pop,
     nvtx_range_push,
     get_nvtx_range_context,
@@ -175,7 +174,6 @@ def forward(
         own_quantized_input = False
         if fp8:
             assert_dim_for_fp8_exec(inputmat, weight)
-            assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
             if save_original_input:
                 assert not isinstance(
                     input_quantizer, Float8Quantizer
diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 3edc0d7c77..0a74c75edd 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -447,23 +447,6 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
         )
 
 
-def assert_dim_for_all_gather(
-    tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
-) -> None:
-    """Assert that tensor dimensions are supported for all-gather"""
-    if with_all_gather:
-        # Float8BlockQuantizer has a fallback path in gather_along_first_dim
-        # that handles non-quantizable local tensors by all-gathering in
-        # high precision first, then quantizing the result.
-        from .tensor import Float8BlockQuantizer
-
-        if isinstance(quantizer, Float8BlockQuantizer):
-            return
-        assert quantizer.is_quantizable(tensor), (
-            "All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
-        )
-
-
 def is_bf16_compatible() -> bool:
     """Replaces torch.cuda.is_bf16_compatible() with an explicit check on device compute capability
     to enforce sm_80 or higher.
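
The fallback this patch settles on is the same for the blockwise FP8, NVFP4, and MXFP8 gathers:
dequantize the shard if needed, all-gather in high precision, then quantize the gathered result.
The sketch below restates that logic outside the diff; it is illustrative only: quantized_storage_cls
and quantizer are placeholders standing in for Transformer Engine's storage class and quantizer
callable, and only the torch / torch.distributed calls are real APIs.

    import warnings
    import torch
    import torch.distributed as dist

    def gather_high_precision(inp, quantizer, out_shape, process_group, quantized_storage_cls):
        """All-gather a shard that is too small (or otherwise unsupported) to quantize locally."""
        warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
        if isinstance(inp, quantized_storage_cls):
            inp = inp.dequantize()  # work on a plain high-precision tensor
        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
        dist.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
        return quantizer(out)  # quantize the gathered result instead of the shards
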
From 19fd9273e1fd8b01cd7defd006a1b6e196c6a073 Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Tue, 10 Feb 2026 14:21:15 -0800
Subject: [PATCH 3/4] Fix lint error

Signed-off-by: Przemek Tredak
---
 transformer_engine/pytorch/utils.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/transformer_engine/pytorch/utils.py b/transformer_engine/pytorch/utils.py
index 0a74c75edd..b1cc3be19d 100644
--- a/transformer_engine/pytorch/utils.py
+++ b/transformer_engine/pytorch/utils.py
@@ -12,7 +12,6 @@
 import numpy as np
 import torch
 
-from .quantized_tensor import Quantizer
 from .torch_version import torch_version
 
 from ..debug.pytorch.debug_quantization import DebugQuantizedTensor
From adfe33bec3a2c7f50856afdaf04acb780d4fcdfd Mon Sep 17 00:00:00 2001
From: Przemek Tredak
Date: Tue, 10 Feb 2026 15:02:20 -0800
Subject: [PATCH 4/4] Keep the previous behavior with dtype

Signed-off-by: Przemek Tredak
---
 transformer_engine/pytorch/distributed.py | 22 +++++++++++-----------
 1 file changed, 11 insertions(+), 11 deletions(-)

diff --git a/transformer_engine/pytorch/distributed.py b/transformer_engine/pytorch/distributed.py
index 2208ff720c..2a65fa272b 100644
--- a/transformer_engine/pytorch/distributed.py
+++ b/transformer_engine/pytorch/distributed.py
@@ -1102,8 +1102,8 @@ def _start_all_gather_fp8_blockwise(
     if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
         warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
         if isinstance(inp, QuantizedTensorStorage):
-            inp = inp.dequantize()  # Dequantize if needed
-        out = torch.empty(out_shape, dtype=inp.dtype, device=inp.device)
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
+        out = torch.empty(out_shape, dtype=dtype, device=device)
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
         out = quantizer(out)
         return out, None
@@ -1118,7 +1118,7 @@ def _start_all_gather_fp8_blockwise(
             "Input and quantizer do not have matching usages. "
             "Dequantizing and requantizing to Float8BlockwiseQTensor."
         )
-        inp = quantizer(inp.dequantize())
+        inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct Float8BlockwiseQTensor output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
@@ -1343,11 +1343,11 @@ def _all_gather_nvfp4(
     ):
         warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
         if isinstance(inp, QuantizedTensorStorage):
-            inp = inp.dequantize()  # Dequantize if needed
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=inp.dtype,
-            device=inp.device,
+            dtype=dtype,
+            device=device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
@@ -1364,7 +1364,7 @@ def _all_gather_nvfp4(
             "Input and quantizer do not have matching usages. "
             "Dequantizing and requantizing to NVFP4."
         )
-        inp = quantizer(inp.dequantize())
+        inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct NVFP4 output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
@@ -1513,11 +1515,11 @@ def _all_gather_mxfp8(
     ):
         warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
         if isinstance(inp, QuantizedTensorStorage):
-            inp = inp.dequantize()  # Dequantize if needed
+            inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
         out = torch.empty(
             out_shape,
-            dtype=inp.dtype,
-            device=inp.device,
+            dtype=dtype,
+            device=device,
             memory_format=torch.contiguous_format,
         )
         torch.distributed.all_gather_into_tensor(out, inp, group=process_group)
@@ -1534,7 +1536,7 @@ def _all_gather_mxfp8(
             "Input and quantizer do not have matching usages. "
            "Dequantizing and requantizing to MXFP8."
         )
-        inp = quantizer(inp.dequantize())
+        inp = quantizer(inp.dequantize(dtype=dtype))

     # Construct MXFP8 output tensor
     out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
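
With this last patch the fallback pins the gather buffer and the dequantized shard to the op's
nominal dtype instead of whatever dtype dequantize() happens to produce, which keeps the previous
dtype behavior and keeps the two in sync (all_gather_into_tensor requires output and input dtypes
to match). A hedged restatement of that invariant, again with placeholder names rather than the
actual Transformer Engine API:

    import torch
    import torch.distributed as dist

    def gather_high_precision(inp, quantizer, out_shape, dtype, device, process_group, quantized_storage_cls):
        # Assumes a plain (non-quantized) `inp` already has dtype `dtype`.
        if isinstance(inp, quantized_storage_cls):
            # Dequantize into the requested dtype explicitly (the PATCH 4 behavior),
            # rather than relying on the storage's default dequantization dtype.
            inp = inp.dequantize(dtype=dtype)
        out = torch.empty(out_shape, dtype=dtype, device=device)
        dist.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
        return quantizer(out)
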