Conversation

@Oleg-Goncharov
Collaborator

Description

This PR adds a new kernel that computes dbias separately for each tensor in a group and outputs a grouped dbias tensor containing per-tensor dbias values.
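Concretely, for each tensor in the group the grouped dbias is the column-wise sum of that tensor's gradient. A minimal host-side reference of that semantics, for illustration only (the function name, float accumulation, and row-major layout are assumptions, not the kernel added by this PR):

    #include <cstddef>
    #include <vector>

    // Reference semantics sketch: dbias[t][k] = sum over rows m of grad_t[m][k].
    // grads[t] is the t-th gradient, flattened row-major with rows[t] rows of `cols` values.
    std::vector<std::vector<float>> grouped_dbias_reference(
        const std::vector<std::vector<float>>& grads,
        const std::vector<size_t>& rows,
        size_t cols) {
      std::vector<std::vector<float>> dbias(grads.size(), std::vector<float>(cols, 0.0f));
      for (size_t t = 0; t < grads.size(); ++t) {
        for (size_t m = 0; m < rows[t]; ++m) {
          for (size_t k = 0; k < cols; ++k) {
            dbias[t][k] += grads[t][m * cols + k];  // column-wise accumulation per tensor
          }
        }
      }
      return dbias;  // one dbias row per tensor in the group
    }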

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added the grouped dbias kernel

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Oleg Goncharov <ogoncharov@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Feb 11, 2026

Greptile Overview

Greptile Summary

This PR changes the grouped MXFP8 quantize+dbias APIs to output a grouped dbias tensor (per-tensor dbias values) instead of a single flat tensor, and wires that through the grouped activation backprop entrypoints. Internally, it adds a grouped dbias reduction kernel to reduce per-tile partial dbias workspace into a [num_tensors, K] grouped output, and updates the C++ test to validate grouped dbias.

Key integration points:

  • Public C API signatures in transformer_engine/common/include/transformer_engine/cast.h now take NVTEGroupedTensor dbias for grouped dbias functions (a call-site sketch follows this list).
  • Dispatch helpers in common/cast/dispatch/quantize.cuh treat dbias as a nullable grouped tensor and pass it to mxfp8::group_quantize.
  • The MXFP8 grouped kernel continues to compute partial dbias into a float workspace, then calls the new grouped reducer in common/cast/core/common.cuh to produce the grouped dbias output.
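A hedged call-site sketch of the updated entry point, assembled from the sequence diagram below; the exact signature (workspace type, presence and position of a cudaStream_t argument, and the workspace-size query protocol) is an assumption to be verified against cast.h:

    #include <cuda_runtime.h>
    #include <transformer_engine/cast.h>

    // Illustrative wrapper only; argument order follows the sequence diagram in this review.
    void run_grouped_quantize_dbias(NVTEGroupedTensor grad_group,
                                    NVTEGroupedTensor out_group,
                                    NVTEGroupedTensor dbias_group,  // grouped output: one dbias row per tensor
                                    NVTETensor workspace,           // assumed plain NVTETensor workspace
                                    cudaStream_t stream) {          // assumed trailing stream argument
      // Per this PR, dbias is an NVTEGroupedTensor holding [num_tensors, K] per-tensor rows.
      nvte_group_quantize_dbias(grad_group, out_group, dbias_group, workspace, stream);
    }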

Main issue found: group_quantize_mxfp8.cuh now enforces full-tile-only grid sizing even on the is_single_tensor path, which previously accepted tail tiles; this is a functional regression for shapes whose dimensions are not multiples of 128.

Confidence Score: 3/5

  • This PR is close to mergeable but has a functional regression in grouped MXFP8 quantization shape support that should be resolved first.
  • Core grouped dbias plumbing and test updates look consistent, and the new grouped reducer is only used under constant-last-dim dbias mode. However, the updated grid/check logic in group_quantize_mxfp8 now rejects shapes that were previously accepted (tail tiles in the is_single_tensor path), which can break existing callers.
  • File needing the closest look: transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh

Important Files Changed

  • tests/cpp/operator/test_cast_mxfp8_grouped.cu: updates the grouped MXFP8 cast tests to use the grouped dbias output (NVTEGroupedTensor) and validate the concatenated per-tensor dbias.
  • transformer_engine/common/activation/gelu.cu: updates grouped dgelu/dqgelu dbias plumbing to pass NVTEGroupedTensor dbias through group_quantize_bwd_helper.
  • transformer_engine/common/activation/relu.cu: updates grouped drelu/dsrelu dbias plumbing to use NVTEGroupedTensor dbias in the group quantize helpers.
  • transformer_engine/common/activation/swiglu.cu: updates grouped dsilu dbias plumbing to use NVTEGroupedTensor dbias in the group quantize helpers.
  • transformer_engine/common/cast/cast.cu: changes the nvte_group_quantize_dbias signature to accept a grouped dbias output and dispatches to the updated group_quantize_bwd_helper.
  • transformer_engine/common/cast/core/common.cuh: adds a ShapeRepresentation enum and a new grouped_reduce_dbias kernel that reduces per-tensor dbias from the workspace into the grouped output (a reduction sketch follows this list).
  • transformer_engine/common/cast/dispatch/quantize.cuh: updates the group quantize helper templates to treat dbias as a nullable NVTEGroupedTensor and convert it via convertNVTEGroupedTensor.
  • transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh: implements the grouped dbias reduction and switches to full-tile-only grid sizing; introduces a functional regression by disallowing tail tiles on the is_single_tensor path.
  • transformer_engine/common/include/transformer_engine/cast.h: updates the public C API signatures for grouped dbias functions to take an NVTEGroupedTensor dbias output instead of an NVTETensor.
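The grouped reducer added in common.cuh is only described above; a minimal sketch of what such a reduction can look like is shown here. It assumes each tensor's partial dbias rows (one float row of length K per processed tile row) sit contiguously in the workspace; the kernel name, the offset/count arrays, and this layout are illustrative assumptions, not the actual implementation.

    #include <cstddef>
    #include <cuda_runtime.h>

    // Hedged sketch of a grouped dbias reduction (not the kernel in common.cuh).
    // Grid: blockIdx.y selects the tensor, blockIdx.x covers the K columns.
    template <typename OType>
    __global__ void grouped_reduce_dbias_sketch(const float* __restrict__ workspace,
                                                OType* __restrict__ dbias_out,
                                                const size_t* __restrict__ partial_offsets,
                                                const size_t* __restrict__ num_partials,
                                                size_t K) {
      const size_t tensor_id = blockIdx.y;
      const size_t col = blockIdx.x * blockDim.x + threadIdx.x;
      if (col >= K) return;

      const float* partials = workspace + partial_offsets[tensor_id] * K;
      float acc = 0.0f;
      for (size_t row = 0; row < num_partials[tensor_id]; ++row) {
        acc += partials[row * K + col];  // sum this tensor's partial dbias rows
      }
      dbias_out[tensor_id * K + col] = static_cast<OType>(acc);  // row t of the grouped dbias
    }

A host launch would use one grid row per tensor, e.g. grid = dim3(DIVUP(K, block_size), num_tensors); the real kernel presumably also distinguishes the shape representations named by the new ShapeRepresentation enum.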

Sequence Diagram

sequenceDiagram
    participant Caller
    participant CastAPI as cast.h/C API
    participant CastCU as common/cast/cast.cu
    participant Dispatch as common/cast/dispatch/quantize.cuh
    participant MX as common/cast/mxfp8/group_quantize_mxfp8.cuh
    participant Reduce as common/cast/core/common.cuh

    Caller->>CastAPI: nvte_group_quantize_dbias(grad_group, out_group, dbias_group, workspace)
    CastAPI->>CastCU: nvte_group_quantize_dbias(...)
    CastCU->>Dispatch: group_quantize_bwd_helper<IS_DBIAS=true>(grad, nullptr, out, dbias_group, workspace)
    Dispatch->>MX: mxfp8::group_quantize<IS_DBIAS=true>(GroupedTensor*, ..., GroupedTensor* dbias)
    MX->>MX: group_quantize_mxfp8_kernel writes partial dbias into workspace (float)
    MX->>Reduce: grouped_reduce_dbias<IType>(shape_rep, num_tensors, logical dims, offsets/dims, dbias_group, workspace)
    Reduce-->>Caller: dbias_group contains per-tensor dbias rows


@greptile-apps greptile-apps bot left a comment


9 files reviewed, 1 comment


Comment on lines +787 to 795
if (!is_single_tensor) {
  NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
             "Number of tensors in a group is larger than "
             "the MAX number of supported descriptors (64).");
  // Only full tiles supported
  NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
             "Last dimension of a grouped tensor should be divisible by 128.");
  blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
}
const dim3 grid(blocks);

NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
const dim3 grid(elts_total / ELTS_PER_CHUNK);
const size_t block_size = THREADS_PER_CHUNK;

Regression: tail tiles disallowed

group_quantize now enforces elts_total % ELTS_PER_CHUNK == 0 and sets grid(elts_total / ELTS_PER_CHUNK) for all shapes, including the is_single_tensor case (shape_rep == SAME_BOTH_DIMS || VARYING_FIRST_DIM). In the previous implementation, the is_single_tensor path sized the grid with DIVUP(first_logical_dim, 128) and DIVUP(last_logical_dim, 128), so tail tiles were accepted.

If callers previously passed non-128-multiple M/K for constant-last-dim groups, this change will now hard-fail. If dropping tail support is intentional, it should be documented and covered by tests; otherwise the grid/check logic likely needs to preserve the old DIVUP behavior for is_single_tensor.
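One way the grid/check logic could preserve the old DIVUP behavior for is_single_tensor while keeping the full-tile requirement for the multi-descriptor path is sketched below; this is a hedged reconstruction from the identifiers visible in this hunk, not a patch from the PR:

    // Sketch only: assumes CHUNK_DIM_Y == CHUNK_DIM_X == 128 as stated in the comment above.
    size_t blocks = 0;
    if (is_single_tensor) {
      // Previous behavior: tail tiles accepted, grid rounded up per dimension.
      blocks = DIVUP(first_logical_dim, CHUNK_DIM_Y) * DIVUP(last_logical_dim, CHUNK_DIM_X);
    } else {
      NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
                 "Number of tensors in a group is larger than "
                 "the MAX number of supported descriptors (64).");
      // Full tiles only on the general grouped path.
      NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
                 "Last dimension of a grouped tensor should be divisible by 128.");
      NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
      blocks = elts_total / ELTS_PER_CHUNK;
    }
    const dim3 grid(blocks);
    const size_t block_size = THREADS_PER_CHUNK;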

@Oleg-Goncharov
Collaborator Author

/te-ci
