15 changes: 12 additions & 3 deletions transformer_engine/pytorch/distributed.py
@@ -1100,6 +1100,9 @@ def _start_all_gather_fp8_blockwise(

# Fall back to high-precision all-gather if FP8 is not supported
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize(dtype=dtype) # Dequantize if needed
out = torch.empty(out_shape, dtype=dtype, device=device)
torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
out = quantizer(out)
Comment on lines 1102 to 1108 (Contributor)

Non-contiguous gather input

In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
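A minimal sketch of the fix this comment suggests, wrapped in a hypothetical helper (the function name and signature are illustrative, not part of the PR): call `.contiguous()` on the gather input before handing it to the collective, matching the plain-tensor and FP8 paths cited above.

```python
import torch
import torch.distributed


def _all_gather_high_precision_fallback(inp, quantizer, out_shape, dtype, device, process_group):
    """Illustrative helper mirroring the fallback added in this PR,
    but with a contiguous gather input so non-contiguous views
    (e.g. after transpose/slicing) do not break the collective."""
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(
        out, inp.contiguous(), group=process_group, async_op=False
    )
    # Requantize the gathered high-precision tensor, as the fallback does.
    return quantizer(out)
```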

@@ -1115,7 +1118,7 @@ def _start_all_gather_fp8_blockwise(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to Float8BlockwiseQTensor."
)
inp = quantizer(inp.dequantize())
inp = quantizer(inp.dequantize(dtype=dtype))

# Construct Float8BlockwiseQTensor output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
@@ -1338,6 +1341,9 @@ def _all_gather_nvfp4(
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize(dtype=dtype) # Dequantize if needed
out = torch.empty(
out_shape,
dtype=dtype,
@@ -1358,7 +1364,7 @@ def _all_gather_nvfp4(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to NVFP4."
)
inp = quantizer(inp.dequantize())
inp = quantizer(inp.dequantize(dtype=dtype))

# Construct NVFP4 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
@@ -1505,6 +1511,9 @@ def _all_gather_mxfp8(
and quantizer is not None
and not quantizer.is_quantizable(inp)
):
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
inp = inp.dequantize(dtype=dtype) # Dequantize if needed
out = torch.empty(
out_shape,
dtype=dtype,
@@ -1525,7 +1534,7 @@ def _all_gather_mxfp8(
"Input and quantizer do not have matching usages. "
"Dequantizing and requantizing to MXFP8."
)
inp = quantizer(inp.dequantize())
inp = quantizer(inp.dequantize(dtype=dtype))

# Construct MXFP8 output tensor
out = quantizer.make_empty(out_shape, dtype=dtype, device=device)
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/layernorm_linear.py
@@ -29,7 +29,6 @@
from ..quantization import FP8GlobalStateManager
from ..utils import (
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
cast_if_needed,
clear_tensor_data,
divide,
@@ -158,7 +157,6 @@ def forward(
inputmat = inp
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather, input_quantizer)

# Cast for native AMP
nvtx_range_push(f"{nvtx_label}.norm_input_cast")
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/layernorm_mlp.py
@@ -40,7 +40,6 @@
init_method_constant,
cast_if_needed,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
clear_tensor_data,
requires_grad,
needs_quantized_gemm,
@@ -331,7 +330,6 @@ def _forward(
inputmat = inp.view((-1, in_features))
if fp8:
assert_dim_for_fp8_exec(inputmat, fc1_weight, fc2_weight)
assert_dim_for_all_gather(inputmat, sequence_parallel, fc1_input_quantizer)

activation_func = _act_func(
activation, FP8GlobalStateManager.get_fp8_recipe() if fp8 else None
2 changes: 0 additions & 2 deletions transformer_engine/pytorch/module/linear.py
@@ -34,7 +34,6 @@
requires_grad,
needs_quantized_gemm,
assert_dim_for_fp8_exec,
assert_dim_for_all_gather,
nvtx_range_pop,
nvtx_range_push,
get_nvtx_range_context,
@@ -175,7 +174,6 @@ def forward(
own_quantized_input = False
if fp8:
assert_dim_for_fp8_exec(inputmat, weight)
assert_dim_for_all_gather(inputmat, with_input_all_gather_nccl, input_quantizer)
if save_original_input:
assert not isinstance(
input_quantizer, Float8Quantizer
11 changes: 0 additions & 11 deletions transformer_engine/pytorch/utils.py
@@ -12,7 +12,6 @@
import numpy as np
import torch

from .quantized_tensor import Quantizer
from .torch_version import torch_version
from ..debug.pytorch.debug_quantization import DebugQuantizedTensor

@@ -447,16 +446,6 @@ def assert_dim_for_fp8_exec(*tensors: List[torch.Tensor]) -> None:
)


def assert_dim_for_all_gather(
tensor: torch.Tensor, with_all_gather: bool, quantizer: Quantizer
) -> None:
"""Assert that tensor dimensions are supported for all-gather"""
if with_all_gather:
assert quantizer.is_quantizable(tensor), (
"All-gather requires quantizable tensor for quantizer " + quantizer.__class__.__name__
)


def is_bf16_compatible() -> bool:
"""Replaces torch.cuda.is_bf16_compatible() with an explicit
check on device compute capability to enforce sm_80 or higher.