[Common] MOE Split dBias #2674
base: main
Conversation
Force-pushed from dad4c68 to aff53ff
Force-pushed from 4881d1b to ac81c85
Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Overview

Greptile Summary

This PR changes the grouped MXFP8 quantize+dbias APIs to output a grouped dbias tensor (per-tensor dbias values) instead of a single flat tensor, and wires that through the grouped activation backprop entrypoints. Internally, it adds a grouped dbias reduction kernel that reduces the per-tile partial dbias workspace into the final per-tensor dbias rows.

Main issue found: the new full-tile requirement rejects tail tiles on the single-tensor path (see the review comment below).

Confidence Score: 3/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant CastAPI as cast.h/C API
    participant CastCU as common/cast/cast.cu
    participant Dispatch as common/cast/dispatch/quantize.cuh
    participant MX as common/cast/mxfp8/group_quantize_mxfp8.cuh
    participant Reduce as common/cast/core/common.cuh
    Caller->>CastAPI: nvte_group_quantize_dbias(grad_group, out_group, dbias_group, workspace)
    CastAPI->>CastCU: nvte_group_quantize_dbias(...)
    CastCU->>Dispatch: group_quantize_bwd_helper<IS_DBIAS=true>(grad, nullptr, out, dbias_group, workspace)
    Dispatch->>MX: mxfp8::group_quantize<IS_DBIAS=true>(GroupedTensor*, ..., GroupedTensor* dbias)
    MX->>MX: group_quantize_mxfp8_kernel writes partial dbias into workspace (float)
    MX->>Reduce: grouped_reduce_dbias<IType>(shape_rep, num_tensors, logical dims, offsets/dims, dbias_group, workspace)
    Reduce-->>Caller: dbias_group contains per-tensor dbias rows
```
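To make the two-stage flow in the diagram concrete, here is a minimal host-side C++ sketch of the same computation: stage one accumulates per-tile partial column sums into a float workspace (what group_quantize_mxfp8_kernel does on device), and stage two reduces those partials into a single dbias row (the grouped_reduce_dbias step). The tile size and layout here are illustrative assumptions, not the kernel's actual parameters.

```cpp
// Hypothetical host-side reference for the two-stage dbias reduction.
// Tile size and layout are illustrative assumptions, not the kernel's.
#include <cstdio>
#include <vector>

constexpr int TILE_Y = 128;  // assumed chunk height (CHUNK_DIM_Y analogue)

// Stage 1: per-tile partial column sums, written to a float workspace.
// grad is one [rows x cols] tensor of the group, row-major.
std::vector<float> partial_dbias(const std::vector<float>& grad, int rows, int cols) {
  const int num_tiles_y = (rows + TILE_Y - 1) / TILE_Y;
  std::vector<float> workspace(num_tiles_y * cols, 0.0f);
  for (int r = 0; r < rows; ++r)
    for (int c = 0; c < cols; ++c)
      workspace[(r / TILE_Y) * cols + c] += grad[r * cols + c];
  return workspace;
}

// Stage 2: reduce the partials across tile rows into one dbias row
// (the grouped_reduce_dbias analogue, applied per tensor).
std::vector<float> reduce_dbias(const std::vector<float>& workspace, int num_tiles_y, int cols) {
  std::vector<float> dbias(cols, 0.0f);
  for (int t = 0; t < num_tiles_y; ++t)
    for (int c = 0; c < cols; ++c)
      dbias[c] += workspace[t * cols + c];
  return dbias;
}

int main() {
  const int rows = 256, cols = 4;
  std::vector<float> grad(rows * cols, 1.0f);  // all-ones: dbias should equal rows
  auto ws = partial_dbias(grad, rows, cols);
  auto dbias = reduce_dbias(ws, (rows + TILE_Y - 1) / TILE_Y, cols);
  printf("dbias[0] = %.0f (expected %d)\n", dbias[0], rows);
  return 0;
}
```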
9 files reviewed, 1 comment
```cpp
if (!is_single_tensor) {
  NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
             "Number of tensors in a group is larger than "
             "the MAX number of supported descriptors (64).");
  // Only full tiles supported
  NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
             "Last dimension of a grouped tensor should be divisible by 128.");
  blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
}
const dim3 grid(blocks);
```

```cpp
NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
const dim3 grid(elts_total / ELTS_PER_CHUNK);
const size_t block_size = THREADS_PER_CHUNK;
```
Regression: tail tiles disallowed
group_quantize now enforces elts_total % ELTS_PER_CHUNK == 0 and sets grid(elts_total / ELTS_PER_CHUNK) for all shapes, including the is_single_tensor case (shape_rep == SAME_BOTH_DIMS || VARYING_FIRST_DIM). In the previous implementation, the is_single_tensor path used DIVUP(first_logical_dim, 128)/DIVUP(last_logical_dim, 128), so tail tiles were accepted.
If callers previously passed non-128-multiple M/K for constant-last-dim groups, this change will now hard-fail. If dropping tail support is intentional, it should be documented and covered by tests; otherwise the grid/check logic likely needs to preserve the old DIVUP behavior for is_single_tensor.
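To see the failure mode concretely, here is a small hedged C++ sketch comparing the two grid computations for a shape with a tail tile. The chunk constants are assumed to be the 128x128 values named in the diff; the shape is an arbitrary example.

```cpp
// Hypothetical illustration of the grid-sizing change; constants assumed
// from the diff (CHUNK_DIM_Y = CHUNK_DIM_X = 128).
#include <cstdio>

constexpr size_t CHUNK_DIM_Y = 128, CHUNK_DIM_X = 128;
constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;

#define DIVUP(a, b) (((a) + (b) - 1) / (b))

int main() {
  // A single-tensor group with a tail tile: M is not a multiple of 128.
  const size_t M = 130, K = 256;
  const size_t elts_total = M * K;

  // Old is_single_tensor path: tail tiles rounded up, so this shape was accepted.
  const size_t old_blocks = DIVUP(M, CHUNK_DIM_Y) * DIVUP(K, CHUNK_DIM_X);
  printf("old grid: %zu blocks\n", old_blocks);  // 2 * 2 = 4

  // New path: hard requirement that elts_total be a whole number of chunks.
  if (elts_total % ELTS_PER_CHUNK != 0) {
    printf("new path: check fails (%zu %% %zu != 0)\n", elts_total, ELTS_PER_CHUNK);
  } else {
    printf("new grid: %zu blocks\n", elts_total / ELTS_PER_CHUNK);
  }
  return 0;
}
```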
/te-ci
Description
This PR adds a new kernel that computes dbias separately for each tensor in a group and outputs a grouped dbias tensor containing per-tensor dbias values.

Fixes # (issue)
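In other words, each entry of the grouped output is the column-wise sum of the corresponding gradient tensor. Below is a minimal sketch of that semantics; the SimpleTensor struct and row-major layout are assumptions for illustration, not the actual GroupedTensor representation.

```cpp
// Hypothetical semantics of the grouped dbias output: one dbias row per tensor,
// where dbias[t][k] = sum over rows m of grad[t][m][k]. Layout is assumed.
#include <cstddef>
#include <vector>

struct SimpleTensor {  // illustrative stand-in for one tensor of the group
  size_t rows, cols;
  std::vector<float> data;  // row-major
};

std::vector<std::vector<float>> grouped_dbias(const std::vector<SimpleTensor>& grads) {
  std::vector<std::vector<float>> out;
  for (const auto& g : grads) {
    std::vector<float> dbias(g.cols, 0.0f);
    for (size_t m = 0; m < g.rows; ++m)
      for (size_t k = 0; k < g.cols; ++k)
        dbias[k] += g.data[m * g.cols + k];
    out.push_back(std::move(dbias));  // per-tensor dbias row
  }
  return out;
}
```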
Type of change
Changes
Checklist: