[C] NVFP4 quantization for GroupedTensor
#2655
Conversation
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Zhongbo Zhu <zhongboz@nvidia.com>
Co-authored-by: Zhongbo Zhu <zhongboz@nvidia.com>
Greptile Summary
This PR adds NVFP4 quantization support for GroupedTensor. Major changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Client
participant API as C API Layer
participant Kernel as CUDA Kernel
participant GPU as GPU Memory
Client->>API: nvte_group_hadamard_transform_cast_fusion_graph_safe()
API->>API: Convert NVTEGroupedTensor to GroupedTensor
API->>API: Convert NVTEQuantizationConfig to QuantizationConfig
API->>API: Validate quant_workspace (>=4 bytes)
alt Stochastic Rounding Enabled
API->>API: Validate RNG state tensor
end
API->>Kernel: Launch group_hadamard_transform_cast_fusion_graph_safe()
Kernel->>GPU: TMA load input tensors
Kernel->>Kernel: Determine tensor ID from offset (binary search)
Kernel->>Kernel: Apply row-wise quantization
Kernel->>Kernel: Apply Hadamard transform
Kernel->>Kernel: Apply column-wise quantization to NVFP4
Kernel->>Kernel: Compute scaling factors (FP8 E4M3)
Kernel->>GPU: Write quantized output + scaling factors
Kernel-->>API: Return
API-->>Client: Return
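The "Determine tensor ID from offset (binary search)" step above is the piece the review flags below. The following is a hedged sketch of what such a lookup typically looks like, assuming offsets holds cumulative per-group start offsets; the actual get_current_tensor_id signature and semantics in this PR may differ:

// Hedged sketch, not the PR's implementation: find the largest group index
// whose cumulative start offset is <= the flat element offset of this tile.
__host__ __device__ inline int find_tensor_id(const int64_t *offsets,
                                              int num_tensors,
                                              int64_t flat_offset) {
  int lo = 0, hi = num_tensors - 1;
  while (lo < hi) {
    int mid = (lo + hi + 1) / 2;  // bias upward so lo converges to the answer
    if (offsets[mid] <= flat_offset) {
      lo = mid;   // group mid starts at or before this offset
    } else {
      hi = mid - 1;  // group mid starts after this offset
    }
  }
  return lo;
}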
3 files reviewed, 3 comments
// TODO(zhongbo): double check the logic here
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);

// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
// TODO(zhongbo): double check the logic here
Multiple TODO comments request logic verification in critical group-index calculation and tensor-layout code. Verify that the group_idx calculation and tensor layout logic are correct before merging.
Suggested change (replaces the TODO comments above with descriptive ones):

// Determine the current tensor group index based on tile offset
int group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                      (scheduler.tile_n_base() * size<1>(epilogue_tiler)) * M,
                                      packed_N, M, offsets);
// Determine quantization scale factor layouts/output splits for this group
TSFDLayout sfd_layout;
int cur_N = static_cast<int>(first_dims[group_idx]);
if constexpr (kEnableSwizzleSFOutput) {
  sfd_layout = tile_to_shape(SwizzledSFDLayoutAtom{}, make_shape(M, cur_N), Step<_2, _1>{});
} else {
  sfd_layout = make_layout(make_shape(M, make_shape(Int<SFVecSize>{}, cur_N / SFVecSize)),
                           make_stride(cur_N / SFVecSize, make_stride(_0{}, _1{})));
}
// Build output tensors for columns and their quant scales
Tensor mD = make_tensor(cute::subbyte_iterator<TD>(reinterpret_cast<TD *>(
                            reinterpret_cast<char *>(QA_COLWISE) + offsets[group_idx] / 2)),
                        make_shape(M, cur_N), DStride{});  // (M,packed_N)
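For the non-swizzled branch, the scale-factor layout follows directly from the block-quantization scheme: one FP8 scale covers SFVecSize consecutive elements along N. Illustrative arithmetic only; SFVecSize = 16 is an assumption here (NVFP4 commonly uses 16-element micro-blocks), while in the kernel it is a template parameter:

// Assumed value for illustration; the kernel's SFVecSize may differ.
constexpr int SFVecSize = 16;
// An M x cur_N group therefore carries M * (cur_N / SFVecSize) scale factors.
inline int num_scale_factors(int M, int cur_N) { return M * (cur_N / SFVecSize); }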
// TODO(zhongbo): double check the logic here
int cur_group_idx = get_current_tensor_id(shape_rep, num_tensors,
                                          global_tile_n_offset * M, packed_N, M, offsets);

if (cur_group_idx != group_idx) {
  group_idx = cur_group_idx;
  c_global_amax_val = shared_storage.global_d_amax[group_idx];
  // update amax
  global_encode_scale = c_global_amax_val > 0.0f
                            ? cutlass::minimum_with_nan_propagation<float>{}(
                                  (fp8_max * fp4_max) / c_global_amax_val,
                                  cutlass::platform::numeric_limits<float>::max())
                            : 1.0f;
  global_decode_scale = 1.0f / global_encode_scale;
  if constexpr (kUseFastMath) {
    global_encode_scale_multiplier = global_encode_scale * fp4_max_inv;
  }
  // TODO(zhongbo): double check the logic here
More TODO comments in the epilogue loop. Verify the group-index recalculation and the amax scaling logic.
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
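For reference while verifying, the scale update above reduces to the following host-side arithmetic. This is a hedged sketch, assuming fp8_max = 448 (the E4M3 maximum) and fp4_max = 6 (the E2M1 maximum), with the kernel's NaN-propagating minimum replaced by std::min:

#include <algorithm>
#include <limits>

struct GroupScales { float encode; float decode; };

inline GroupScales compute_group_scales(float group_amax) {
  constexpr float fp8_max = 448.0f;  // max finite FP8 E4M3 magnitude
  constexpr float fp4_max = 6.0f;    // max finite FP4 E2M1 magnitude
  // Map the group's amax onto the representable FP4 * FP8 range; clamp so a
  // tiny amax cannot produce an infinite encode scale.
  float encode = group_amax > 0.0f
                     ? std::min((fp8_max * fp4_max) / group_amax,
                                std::numeric_limits<float>::max())
                     : 1.0f;
  return {encode, 1.0f / encode};
}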
using transformer_engine::detail::ShapeRepresentation;

void *input_base_ptr = reinterpret_cast<void *>(input->data.dptr);
// TODO(zhongbo): add input sanity checks here
Add the input sanity checks noted in the TODO.
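A hedged sketch of the kind of checks the TODO calls for; NVTE_CHECK is the macro Transformer Engine uses for such validation, and the exact fields checked here are assumptions based on the surrounding code:

NVTE_CHECK(input != nullptr, "Input grouped tensor is null.");
NVTE_CHECK(input->data.dptr != nullptr, "Input data pointer is null.");
NVTE_CHECK(input->num_tensors > 0, "Grouped tensor must contain at least one tensor.");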
Fixes #2510
6 files reviewed, 2 comments
QuantizationConfig quant_config_cpp;
if (quant_config != nullptr) {
  quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}

if (input_tensor->num_tensors == 0) {
  return;
}

// Call the multi-tensor Hadamard transform amax implementation.
group_hadamard_transform_cast_fusion_graph_safe(
    input_tensor, output_tensor, *convertNVTETensorCheck(hadamard_matrix), quant_config_cpp,
    *quant_workspace_tensor, stream);
Uninitialized quant_config used
QuantizationConfig quant_config_cpp; is left default-initialized when quant_config == nullptr (with indeterminate members if the type is an aggregate without member initializers), yet it is still passed into group_hadamard_transform_cast_fusion_graph_safe(...) (line 1511). Reading those indeterminate values is undefined behavior and can feed garbage config values into the kernel launch path. Consider value-initializing (QuantizationConfig quant_config_cpp{};) or returning an error when quant_config is required.
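A minimal illustration of the suggested fix, value-initializing so every member is zeroed even when quant_config is null (assuming QuantizationConfig is an aggregate without default member initializers):

QuantizationConfig quant_config_cpp{};  // value-initialization: members zeroed
if (quant_config != nullptr) {
  quant_config_cpp = *reinterpret_cast<QuantizationConfig *>(quant_config);
}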
/te-ci |
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
5 files reviewed, no comments
zhongbozhu left a comment
Conditional approval. More changes to come in #2600.
Description
Pieces taken from #2600.
Type of change
Changes
Checklist: