[Common] Fuse pre-swizzling into grouped MXFP8 quantization kernel #2630
Conversation
Greptile Summary: This PR adds optional pre-swizzling support to the grouped MXFP8 quantization kernel, allowing scaling factors to be stored in the format expected by GEMM operations. The implementation adds a new template parameter, `WITH_GEMM_SWIZZLED_SCALES`, that selects the swizzled scale layout.
Confidence Score: 5/5
4 files reviewed, no comments
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
Force-pushed from bf07d9d to ed61ff7
for more information, see https://pre-commit.ci
/te-ci
2 files reviewed, no comments
ksivaman left a comment:
LGTM
2 files reviewed, 1 comment
```cpp
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
  scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
                                      DIVUP(rows, static_cast<size_t>(128)));
} else {
  scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;
```
Type inconsistency: scale_idx is declared as size_t in the colwise path but as int in the rowwise path (line 615). Should use consistent type (size_t) in both paths.
Description
This PR fuses pre-swizzling into the grouped MXFP8 quantization kernel so that scaling factors are stored in the format expected by GEMM.
Type of change
Changes
GroupedTensor
Checklist: