Fix FP8 block scaling with sequence parallel #2637
base: main
Conversation
Signed-off-by: Chen Cui <chcui@nvidia.com>
Greptile Overview

Greptile Summary

This PR removes a pre-check that required sequence-parallel all-gather inputs to be quantizable, and instead relies on the high-precision fallback path in gather_along_first_dim. The main functional changes are concentrated in distributed.py.

Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant M as Module (Linear/LNLinear/LNMLP)
    participant Q as *Quantizer
    participant D as distributed.gather_along_first_dim
    participant AG as torch.distributed.all_gather_into_tensor
    M->>Q: (optional) quantize local activation
    M->>D: gather_along_first_dim(inp, tp_group, quantizer)
    alt Float8BlockQuantizer and quantizable
        D->>D: _start_all_gather_fp8_blockwise
        D->>AG: all_gather_into_tensor(FP8 blocks + scales)
        D-->>M: Float8BlockwiseQTensorStorage (+ async handle)
    else Not quantizable / unsupported dims
        D->>D: high-precision fallback
        D->>Q: (if QuantizedTensorStorage) dequantize(dtype)
        D->>AG: all_gather_into_tensor(out_fp, inp_fp)
        D->>Q: quantize gathered output
        D-->>M: Quantized output tensor
    end
```
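For readers skimming the diagram, here is a condensed Python sketch of the high-precision fallback branch, assembled from the diff hunks quoted later in this thread. The function name and the duck-typed `dequantize` check are illustrative rather than the exact Transformer Engine code, and the `.contiguous()` call reflects the contiguity concern raised in a later review comment.

```python
import warnings
import torch

def _gather_high_precision_fallback(inp, out_shape, dtype, device, process_group, quantizer):
    """Sketch of the fallback: dequantize -> plain all-gather -> re-quantize."""
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    # If the local shard is already quantized, bring it back to high precision
    # in the requested dtype before communicating.
    if hasattr(inp, "dequantize"):
        inp = inp.dequantize(dtype=dtype)
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(
        out, inp.contiguous(), group=process_group, async_op=False
    )
    # Quantize the gathered full tensor so downstream kernels still receive
    # the representation the quantizer produces.
    return quantizer(out)
```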
Perform all-gather in high precision if the input tensor is too small to quantize.
Signed-off-by: Tim Moon <tmoon@nvidia.com>
2 files reviewed, no comments
timmoon10 left a comment
I've tweaked the module logic so that it no longer checks that the local tensor can be quantized. In addition to FP8 block scaling, MXFP8 and NVFP4 will also support high-precision all-gathers. That said, this may not be fully robust for LayerNormLinear and LayerNormMLP, since the norm kernel may still expect to output to a quantized local tensor.
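A hedged illustration of the control flow this comment describes; this is not the actual module code, and the call shape `gather_along_first_dim(inp, tp_group, quantizer)` is taken from the sequence diagram above.

```python
def _sequence_parallel_gather(inp, quantizer, tp_group, gather_along_first_dim):
    # Previously (sketch): the caller asserted quantizability up front, e.g.
    #   assert quantizer.is_quantizable(inp), "All-gather requires quantizable tensor ..."
    # Now the local check is dropped; gather_along_first_dim chooses between a
    # quantized all-gather and the high-precision fallback internally, which is
    # also what enables high-precision all-gathers for MXFP8 and NVFP4.
    return gather_along_first_dim(inp, tp_group, quantizer=quantizer)
```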
/te-ci pytorch L1
2 files reviewed, 3 comments
```python
out = torch.empty(out_shape, dtype=dtype, device=device)
warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
if isinstance(inp, QuantizedTensorStorage):
    inp = inp.dequantize()  # Dequantize if needed
```
The dequantize() method defaults to dtype=torch.float32. Consider whether this is always appropriate for the fallback path, especially when the original tensor might have been in a different precision (e.g., bfloat16).
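A minimal sketch of what the comment suggests, reusing names from the hunk above (`inp`, `dtype`, and `QuantizedTensorStorage` are assumed to be in scope); the later revision of this diff does exactly this.

```python
if isinstance(inp, QuantizedTensorStorage):
    # Pass the gather dtype explicitly so the fallback keeps the original
    # precision (e.g. bfloat16) rather than dequantizing to the float32 default.
    inp = inp.dequantize(dtype=dtype)
```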
```python
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
```
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
```python
):
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize()  # Dequantize if needed
```
Same consideration as in _start_all_gather_fp8_blockwise: the dequantize() method defaults to dtype=torch.float32, which may not match the original tensor's precision.
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
5 files reviewed, no comments
Signed-off-by: Przemek Tredak <ptredak@nvidia.com>
/te-ci pytorch L1
5 files reviewed, 1 comment
```python
if not quantizer.is_quantizable(inp) or quantizer.block_scaling_dim != 1:
    warnings.warn("Cannot quantize input tensor. Performing all-gather in high precision.")
    if isinstance(inp, QuantizedTensorStorage):
        inp = inp.dequantize(dtype=dtype)  # Dequantize if needed
    out = torch.empty(out_shape, dtype=dtype, device=device)
    torch.distributed.all_gather_into_tensor(out, inp, group=process_group, async_op=False)
    out = quantizer(out)
```
Non-contiguous gather input
In the new high-precision fallback (if not quantizer.is_quantizable(inp) ...), all_gather_into_tensor(out, inp, ...) passes inp directly. Elsewhere in this same module the plain-tensor path uses inp.contiguous() (distributed.py:1737-1742) and the FP8 path uses _data.contiguous() (distributed.py:1031-1035), which strongly suggests the collective expects contiguous inputs. If inp is a non-contiguous view (common after transpose/slicing), this fallback can raise at runtime. This same issue also appears in the NVFP4 and MXFP8 high-precision fallbacks (distributed.py:1353 and :1523).
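A minimal sketch of the adjustment this comment points toward, reusing names from the quoted hunk (`out`, `inp`, `process_group`, and `torch` are assumed to be in scope): make the gather input contiguous, mirroring the other code paths in the module.

```python
torch.distributed.all_gather_into_tensor(
    out, inp.contiguous(), group=process_group, async_op=False
)
```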
Description
Problem
Using Float8BlockQuantizer with sequence parallel fails with

AssertionError: All-gather requires quantizable tensor for quantizer Float8BlockQuantizer

when the local tensor dimensions aren't divisible by 128 (a shape sketch follows the solution items below).

Solution
Skip the assert_dim_for_all_gather check for Float8BlockQuantizer since gather_along_first_dim already has a fallback path
Fix the fallback in _start_all_gather_fp8_blockwise to handle already-quantized inputs by dequantizing before high-precision all-gather
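Returning to the failure condition in the problem statement above, here is a hypothetical sketch of the shape check that decides between the two paths. The helper name and exact check are illustrative; the 128 comes from the Float8 block size, and the real `is_quantizable` logic in Transformer Engine may differ.

```python
def can_gather_quantized(local_shape, block_size=128):
    # FP8 block scaling needs the local tensor dimensions to be multiples of
    # the 128-element block size; otherwise gather_along_first_dim now takes
    # the high-precision fallback instead of asserting.
    return all(dim % block_size == 0 for dim in local_shape)

can_gather_quantized((4096, 1024))  # True  -> FP8 blockwise all-gather
can_gather_quantized((4096, 1000))  # False -> high-precision fallback
```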
Note
The fallback path (high-precision all-gather → quantize) may increase the communication overhead.
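A rough back-of-the-envelope estimate of that overhead, assuming bfloat16 activations and one float32 scale per 128-element block for 1x128 block scaling (an assumption; the actual scale layout may differ):

```python
def gathered_bytes(numel, high_precision, block_size=128):
    if high_precision:
        return numel * 2                      # bfloat16: 2 bytes per element
    return numel + (numel // block_size) * 4  # FP8 data + float32 block scales

n = 4096 * 1024
gathered_bytes(n, high_precision=True)   # 8_388_608 bytes
gathered_bytes(n, high_precision=False)  # 4_325_376 bytes -> fallback roughly doubles traffic
```

Note also that in the fallback path each rank re-quantizes the full gathered tensor rather than only its local shard, which adds some compute on top of the extra communication.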
Verification
The code change does not alter convergence behavior

When SP is enabled, the previous code did not run at all (it failed with the assertion error described above); when SP is disabled, this change does not affect anything.
