From d33114b83681424a8314db4302dd961221822f38 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Wed, 11 Feb 2026 17:29:35 +0000 Subject: [PATCH 1/5] Implemented the kernel with split dbias Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 116 +++++++++--------- transformer_engine/common/activation/gelu.cu | 8 +- transformer_engine/common/activation/relu.cu | 8 +- .../common/activation/swiglu.cu | 4 +- transformer_engine/common/cast/cast.cu | 2 +- .../common/cast/core/common.cuh | 94 ++++++++++++++ .../common/cast/dispatch/quantize.cuh | 11 +- .../cast/mxfp8/group_quantize_mxfp8.cuh | 57 ++++----- .../common/include/transformer_engine/cast.h | 12 +- 9 files changed, 200 insertions(+), 112 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 8b084ca452..3f246b19aa 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,8 +58,7 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise, - const bool is_single_tensor) + const size_t scales_stride_colwise) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -169,10 +168,8 @@ void compute_ref(const ProcessingMethod processing_method, } } - if (is_single_tensor) { - for (size_t j = 0; j < cols; ++j) { - output_dbias[j] = static_cast(output_dbias_fp32[j]); - } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); } } @@ -250,12 +247,16 @@ void performTest(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT); + const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; size_t elts_num = 0; size_t rowwise_sfs_num = 0; size_t colwise_sfs_num = 0; + size_t sum_of_last_dims = 0; std::vector rowwise_scales_first_dim(num_tensors, 0); std::vector rowwise_scales_last_dim(num_tensors, 0); @@ -263,6 +264,7 @@ void performTest(const ProcessingMethod processing_method, std::vector colwise_scales_first_dim(num_tensors, 0); std::vector colwise_scales_last_dim(num_tensors, 0); std::vector colwise_scales_offset(num_tensors + 1, 0); + std::vector dbias_offsets(num_tensors + 1, 0); for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; @@ -285,13 +287,13 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; - + sum_of_last_dims += K; + rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; + dbias_offsets[t+1] = sum_of_last_dims; } - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); - std::vector scales_rowwise_shape = {rowwise_sfs_num}; std::vector scales_colwise_shape = {colwise_sfs_num}; @@ -311,7 +313,7 @@ void performTest(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? colwise_sfs_num : 0); - std::vector ref_output_dbias(is_single_tensor ? 
cols : 0); + std::vector ref_output_dbias(sum_of_last_dims, static_cast(0.0f)); for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); @@ -336,6 +338,7 @@ void performTest(const ProcessingMethod processing_method, const size_t in_data_size = elts_num * sizeof(InputType); const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType); const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0); const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0); @@ -345,6 +348,7 @@ void performTest(const ProcessingMethod processing_method, InputType* grad_data_d; InputType* in_data_d; + InputType* dbias_out_data_d; OutputType* out_data_rowwise_d; OutputType* out_data_colwise_d; fp8e8m0* out_scales_rowwise_d; @@ -366,6 +370,10 @@ void performTest(const ProcessingMethod processing_method, cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + + std::vector dbias_logical_shape_vec= {num_tensors, cols}; + NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), + dbias_logical_shape_vec.size()); NVTEShape first_dims_shape_; NVTEShape last_dims_shape_; @@ -382,11 +390,12 @@ void performTest(const ProcessingMethod processing_method, NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_); NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor); + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; @@ -433,52 +442,40 @@ void performTest(const ProcessingMethod processing_method, nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); } - Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + if (compute_dbias) { + cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); + cudaMemset(dbias_out_data_d, 0, dbias_data_size); + NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; + nvte_set_grouped_tensor_param(&output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); + } // Reference (CPU) - if (is_single_tensor) { - - const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32); - const size_t unpadded_colwise_blocks_X = cols; + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; - const size_t scales_stride_rowwise = 
round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); - const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); + const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; + const size_t scales_stride_colwise = colwise_scales_last_dim[t]; + const size_t data_offset = offsets_h[t]; + const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; + const size_t colwise_sfs_offset = colwise_scales_offset[t]; + const size_t dbias_offset = dbias_offsets[t]; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; + InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset; compute_ref( - processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), - out_data_rowwise_ref.data(), out_data_colwise_ref.data(), - out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), - ref_output_dbias.data(), rows, cols, + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } else { - for (size_t t = 0; t < num_tensors; ++t) { - const size_t M = first_dims_h[t]; - const size_t K = last_dims_h[t]; - - const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; - const size_t scales_stride_colwise = colwise_scales_last_dim[t]; - const size_t data_offset = offsets_h[t]; - const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; - const size_t colwise_sfs_offset = colwise_scales_offset[t]; - - const InputType* const grad_ptr = grad_data.data() + data_offset; - const InputType* const in_ptr = in_data.data() + data_offset; - OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; - OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; - fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; - fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; - - compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, - out_data_rowwise_ptr, out_data_colwise_ptr, - out_scales_rowwise_ptr, out_scales_colwise_ptr, - ref_output_dbias.data(), M, K, - scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } + scales_stride_colwise); } // GPU @@ -489,9 +486,9 @@ void performTest(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_DBIAS_DACT: { @@ -502,10 +499,10 @@ void 
performTest(const ProcessingMethod processing_method, else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -566,9 +563,10 @@ void performTest(const ProcessingMethod processing_method, out_data_colwise_h.data(), rows, cols, false, mismatches_elts); } - if (processing_method == ProcessingMethod::CAST_DBIAS - || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - { + if (compute_dbias) { + Tensor output_dbias("output_dbias", std::vector{ sum_of_last_dims }, itype); + cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice); + auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -581,6 +579,7 @@ void performTest(const ProcessingMethod processing_method, cudaFree(grad_data_d); cudaFree(in_data_d); + cudaFree(dbias_out_data_d); cudaFree(first_dims_d); cudaFree(last_dims_d); cudaFree(offsets_d); @@ -628,7 +627,6 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 2, 384,160, 128,256}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index d209ea8d47..ea864813bf 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dqgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { 
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index b6f758caf6..fc9122b7ec 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_drelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_drelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsrelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 77d5b6867f..12478af4cf 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsilu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 582172a88e..f00b34b3d9 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d } void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias); using namespace 
transformer_engine; diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 0997b01f7e..24a7e7fa79 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -22,6 +22,14 @@ namespace transformer_engine { namespace dispatch { namespace common { + +enum ShapeRepresentation { + SAME_BOTH_DIMS = 0, + VARYING_FIRST_DIM = 1, + VARYING_LAST_DIM = 2, + VARYING_BOTH_DIMS = 3 +}; + inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) { const size_t N = product(t->data.shape); const bool isFullTile = (N % elems_per_block == 0); @@ -78,6 +86,61 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) } stg_vec.store_to(thread_out_base); } + +template +__global__ void __launch_bounds__(THREADS_PER_BLOCK) + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const offsets_ptr, + const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, + OType *const dbias_output, + const float *dbias_partial, + const size_t chunk_dim_Y) { + using ComputeVec = Vec; + using OutputVec = Vec; + + const size_t tensor_id = blockIdx.y; + const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (first_logical_dim / num_tensors) + : first_dims_ptr[tensor_id]; + + const size_t rows = tensor_rows / chunk_dim_Y; + const size_t cols = last_logical_dim; + + const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + + const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; + + if (thread_id * nvec >= cols) { + return; + } + + const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec; + OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec; + + ComputeVec ldg_vec; + ComputeVec acc_vec; + acc_vec.clear(); + for (int i = 0; i < rows; ++i) { + ldg_vec.load_from(thread_in_base + i * cols); +#pragma unroll + for (int e = 0; e < nvec; ++e) { + acc_vec.data.elt[e] += ldg_vec.data.elt[e]; + } + } + + OutputVec stg_vec; +#pragma unroll + for (int e = 0; e < nvec; ++e) { + stg_vec.data.elt[e] = static_cast(acc_vec.data.elt[e]); + } + stg_vec.store_to(thread_out_base); +} } // namespace kernel template @@ -96,6 +159,37 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, NVTE_CHECK_CUDA(cudaGetLastError()); } +template +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, + const size_t num_tensors, + const size_t first_logical_dim, + const size_t last_logical_dim, + const int64_t *const data_tensor_offsets_ptr, + const int64_t *const data_tensor_first_dims_ptr, + const int64_t *const data_tensor_last_dims_ptr, + GroupedTensor *dbias, + const float *workspace_ptr, + const size_t chunk_dim_Y, + cudaStream_t stream) { + using namespace kernel; + constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 + constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType); + + NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape."); + + const size_t blocks_X = DIVUP(last_logical_dim, THREADS_PER_BLOCK * reduce_dbias_nvec); + const size_t blocks_Y = num_tensors; + const dim3 grid(blocks_X, blocks_Y); + + group_reduce_dbias_kernel + <<>>( + shape_rep, num_tensors, 
first_logical_dim, last_logical_dim, + data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + + NVTE_CHECK_CUDA(cudaGetLastError()); +} + } // namespace common } // namespace dispatch } // namespace transformer_engine diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh index b83df1dedf..1bb6552349 100644 --- a/transformer_engine/common/cast/dispatch/quantize.cuh +++ b/transformer_engine/common/cast/dispatch/quantize.cuh @@ -380,13 +380,13 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); const NVTEGroupedTensor activation = nullptr; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config @@ -417,8 +417,9 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor template void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input, - NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace, - const NVTEQuantizationConfig quant_config, cudaStream_t stream) { + NVTEGroupedTensor output, NVTEGroupedTensor dbias, + NVTETensor workspace, const NVTEQuantizationConfig quant_config, + cudaStream_t stream) { using namespace detail; NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output); @@ -426,7 +427,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTe const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad); const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input); GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output); - Tensor *dbias_tensor = convertNVTETensor(dbias); + GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias); Tensor *workspace_tensor = convertNVTETensor(workspace); // Quantization config diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a29a09836e..a5a46d2c6e 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -28,19 +28,14 @@ namespace dispatch { namespace mxfp8 { namespace group_quantize_kernel { +using namespace dispatch::common; + constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64; __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS]; -enum ShapeRepresentation { - SAME_BOTH_DIMS = 0, - VARYING_FIRST_DIM = 1, - VARYING_LAST_DIM = 2, - VARYING_BOTH_DIMS = 3 -}; - constexpr size_t SCALE_DIM_Y = 32; constexpr size_t SCALE_DIM_X = 32; @@ -144,11 +139,14 @@ __device__ __forceinline__ void 
modify_base_tensor_map(const CUtensorMap base_te if constexpr (is_blackwell) { const size_t global_stride_bytes = global_dim_X * data_type_size_bytes; if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) { - NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned."); + NVTE_DEVICE_ERROR("Shape not supported. Data stride must be 16B aligned."); } if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } + if (global_dim_X % CHUNK_DIM_X != 0) { + NVTE_DEVICE_ERROR("The grouped tensor must be divisible by 128x128 tiles without a tail tile."); + } asm volatile( "{\n\t" @@ -749,8 +747,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel template void group_quantize(const GroupedTensor *input, const GroupedTensor *activations, - const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace, - cudaStream_t stream) { + const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias, + Tensor *workspace, cudaStream_t stream) { using namespace group_quantize_kernel; checkCuDriverContext(stream); @@ -801,22 +799,14 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - size_t blocks = 0; - - if (is_single_tensor) { - const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); - const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); - blocks = blocks_Y * blocks_X; - } else { - NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS, + if (!is_single_tensor) { + NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); - // Only full tiles supported - NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0, - "Last dimension of a grouped tensor should be divisible by 128."); - blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); } - const dim3 grid(blocks); + + NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); + const dim3 grid(elts_total / ELTS_PER_CHUNK); const size_t block_size = THREADS_PER_CHUNK; const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; @@ -845,18 +835,20 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated"); } - const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); - const size_t dbias_cols = last_logical_dim; if constexpr (IS_DBIAS) { NVTE_CHECK(is_single_tensor, "DBias is only supported for tensors with the const last dimension."); NVTE_CHECK(dbias->data.dtype == input->dtype(), "DBias must have the same type as input_tensor."); - NVTE_CHECK(dbias->data.shape == std::vector{last_logical_dim}, "Wrong shape of DBias."); - NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + std::vector expected_shape_dbias_tensor = {num_tensors, last_logical_dim}; + NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias."); + + NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor."); + const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t dbias_workspace_cols = last_logical_dim; if (workspace->data.dptr == nullptr) { - workspace->data.shape = {dbias_rows, dbias_cols}; + workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols}; workspace->data.dtype = DType::kFloat32; return; } @@ -972,9 +964,12 @@ void 
group_quantize(const GroupedTensor *input, const GroupedTensor *activations last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - if constexpr (IS_DBIAS) { - common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream); - } + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, + offsets_ptr, first_dims_ptr, last_dims_ptr, + dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 04712d3003..88c483d400 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -161,7 +161,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d * \param[in] stream CUDA stream used for the operation. */ void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the GeLU backward along columns. @@ -207,7 +207,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. @@ -253,7 +253,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. @@ -299,7 +299,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. @@ -345,7 +345,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. 
* Additionally, reduces the result of the Squared ReLU backward along columns. @@ -391,7 +391,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, From 53e99c3b434e9d6c852d4170dc5d9b9f55bd58b4 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Feb 2026 17:30:57 +0000 Subject: [PATCH 2/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 4 +- .../common/cast/core/common.cuh | 46 ++++++++----------- .../cast/mxfp8/group_quantize_mxfp8.cuh | 8 ++-- .../common/include/transformer_engine/cast.h | 15 ++++-- 4 files changed, 34 insertions(+), 39 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 3f246b19aa..889de78a52 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -288,7 +288,7 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; sum_of_last_dims += K; - + rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; dbias_offsets[t+1] = sum_of_last_dims; @@ -370,7 +370,7 @@ void performTest(const ProcessingMethod processing_method, cudaMemcpy(offsets_d, offsets_h.data(), offsets_size, cudaMemcpyHostToDevice); NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); - + std::vector dbias_logical_shape_vec= {num_tensors, cols}; NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), dbias_logical_shape_vec.size()); diff --git a/transformer_engine/common/cast/core/common.cuh b/transformer_engine/common/cast/core/common.cuh index 24a7e7fa79..a4e033939b 100644 --- a/transformer_engine/common/cast/core/common.cuh +++ b/transformer_engine/common/cast/core/common.cuh @@ -89,30 +89,25 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK) template __global__ void __launch_bounds__(THREADS_PER_BLOCK) - group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, - const int64_t *const offsets_ptr, - const int64_t *const first_dims_ptr, - const int64_t *const last_dims_ptr, - OType *const dbias_output, - const float *dbias_partial, - const size_t chunk_dim_Y) { + group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, + const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr, + const int64_t *const last_dims_ptr, OType *const dbias_output, + const float *dbias_partial, const size_t chunk_dim_Y) { using ComputeVec = Vec; using OutputVec = Vec; const size_t tensor_id = blockIdx.y; const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (first_logical_dim / num_tensors) - : first_dims_ptr[tensor_id]; - + ? 
(first_logical_dim / num_tensors) + : first_dims_ptr[tensor_id]; + const size_t rows = tensor_rows / chunk_dim_Y; const size_t cols = last_logical_dim; const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS) - ? (tensor_id * (tensor_rows / chunk_dim_Y)) - : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); + ? (tensor_id * (tensor_rows / chunk_dim_Y)) + : (offsets_ptr[tensor_id] / cols / chunk_dim_Y); const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x; @@ -160,16 +155,12 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows, } template -void grouped_reduce_dbias(const ShapeRepresentation shape_rep, - const size_t num_tensors, - const size_t first_logical_dim, - const size_t last_logical_dim, +void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors, + const size_t first_logical_dim, const size_t last_logical_dim, const int64_t *const data_tensor_offsets_ptr, const int64_t *const data_tensor_first_dims_ptr, - const int64_t *const data_tensor_last_dims_ptr, - GroupedTensor *dbias, - const float *workspace_ptr, - const size_t chunk_dim_Y, + const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias, + const float *workspace_ptr, const size_t chunk_dim_Y, cudaStream_t stream) { using namespace kernel; constexpr size_t reduce_dbias_store_bytes = 8; // stg.64 @@ -181,11 +172,10 @@ void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t blocks_Y = num_tensors; const dim3 grid(blocks_X, blocks_Y); - group_reduce_dbias_kernel - <<>>( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - data_tensor_offsets_ptr, data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, - reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); + group_reduce_dbias_kernel<<>>( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr, + data_tensor_first_dims_ptr, data_tensor_last_dims_ptr, + reinterpret_cast(dbias->data.dptr), workspace_ptr, chunk_dim_Y); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a5a46d2c6e..14e4024c8d 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -145,7 +145,8 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } if (global_dim_X % CHUNK_DIM_X != 0) { - NVTE_DEVICE_ERROR("The grouped tensor must be divisible by 128x128 tiles without a tail tile."); + NVTE_DEVICE_ERROR( + "The grouped tensor must be divisible by 128x128 tiles without a tail tile."); } asm volatile( @@ -966,9 +967,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations if constexpr (IS_DBIAS) { common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, - offsets_ptr, first_dims_ptr, last_dims_ptr, - dbias, workspace_ptr, CHUNK_DIM_Y, stream); + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 88c483d400..95d01fd8bf 100644 --- 
a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -207,7 +207,8 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the SiLU backward along columns. @@ -253,7 +254,8 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the ReLU backward along columns. @@ -299,7 +301,8 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu */ void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Quick GeLU backward along columns. @@ -345,7 +348,8 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8. * Additionally, reduces the result of the Squared ReLU backward along columns. @@ -391,7 +395,8 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. 
* If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING, From 43a7a44eb6da77330e227003a3f918ddf207a42a Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Thu, 12 Feb 2026 15:20:33 +0000 Subject: [PATCH 3/5] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../common/cast/mxfp8/group_quantize_mxfp8.cuh | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 14e4024c8d..6e314bfdf0 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -965,11 +965,11 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr, scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr); - if constexpr (IS_DBIAS) { - common::grouped_reduce_dbias( - shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, - first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); - } + if constexpr (IS_DBIAS) { + common::grouped_reduce_dbias( + shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr, + first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream); + } NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*) ); // NOLINT(*) From 3ca3f6bf5aa4053b254f28bb7b09e8346da823ab Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Feb 2026 14:12:36 +0000 Subject: [PATCH 4/5] Relaxed constraints on the last dimension Signed-off-by: Oleg Goncharov --- tests/cpp/operator/test_cast_mxfp8_grouped.cu | 20 ++++++++++++---- .../cast/mxfp8/group_quantize_mxfp8.cuh | 23 +++++++++++-------- 2 files changed, 29 insertions(+), 14 deletions(-) diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 889de78a52..6d0dde4e83 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -627,6 +627,8 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, + {VARYING_FIRST_DIM, 3, 1024,144, 128,384,512}, + {VARYING_FIRST_DIM, 4, 1536,160, 128,384,512,512}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, @@ -692,23 +694,31 @@ TEST_P(GroupedFusedCastMXFP8TestSuite, Test) { } } offsets[t+1] = offsets[t] + first_dims[t] * last_dims[t]; - // Skips tests if tensor shape is not as required by the kernel - if ((first_dims[t] % 128 != 0) || (last_dims[t] % 32 != 0)) { + // Skip tests when the tensor shape is incompatible with the kernel. + // The TMA engine requires strides to be 16-byte aligned. + if ((first_dims[t] % 128 != 0) || (last_dims[t] % 16 != 0)) { + GTEST_SKIP(); + } + // If a grouped tensor has a varying last dimension, it must be a multiple of 128. + // Otherwise, computing the grid size adds runtime overhead in the non-persistent kernel, + // since the relevant tensor metadata resides in device memory. 
+ constexpr size_t CHUNK_DIM_X = 128; + if (!is_single_tensor && (last_dims[t] % CHUNK_DIM_X != 0)) { GTEST_SKIP(); } } - // Skips DBias tests if last dimension of tensors variates + // Skip dBias tests when tensors in the group have different last dimensions. if ((processing_method == ProcessingMethod::CAST_DBIAS || processing_method == ProcessingMethod::CAST_DBIAS_DACT) && !is_single_tensor) { GTEST_SKIP(); } - // Skips non Act tests if the Activation type is not an identity + // Skip non-activation tests when the activation type is not Identity. if ((processing_method == ProcessingMethod::CAST_ONLY || processing_method == ProcessingMethod::CAST_DBIAS) && activation != ActivationKind::Identity) { GTEST_SKIP(); } - // Skips Act tests if the Activation is an identity + // Skip activation tests when the activation type is Identity. if ((processing_method == ProcessingMethod::CAST_DBIAS_DACT || processing_method == ProcessingMethod::CAST_DACT || processing_method == ProcessingMethod::CAST_ACT) && (activation == ActivationKind::Identity)) { diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 6e314bfdf0..380aa520bd 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -106,6 +106,9 @@ __device__ __forceinline__ size_t get_tensor_rows_num( rows_num = static_cast(first_dims_ptr[tensor_id]); break; } + if (rows_num % 128 != 0) { + NVTE_DEVICE_ERROR("First dimension of each tensor in a group must be divisible by 128."); + } return rows_num; } @@ -144,10 +147,6 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) { NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned"); } - if (global_dim_X % CHUNK_DIM_X != 0) { - NVTE_DEVICE_ERROR( - "The grouped tensor must be divisible by 128x128 tiles without a tail tile."); - } asm volatile( "{\n\t" @@ -800,22 +799,28 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations const size_t num_tensors = input->num_tensors; - if (!is_single_tensor) { + size_t blocks = 0; + if (is_single_tensor) { + const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y); + const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X); + blocks = blocks_Y * blocks_X; + } else { NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS, "Number of tensors in a group is larger than " "the MAX number of supported descriptors (64)."); + // Only full tiles supported + NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); + blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X); } - - NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported."); - const dim3 grid(elts_total / ELTS_PER_CHUNK); const size_t block_size = THREADS_PER_CHUNK; + const dim3 grid(blocks); const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales; // Logical shape of a tensor with varying all dims is [1, M*K] if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) { NVTE_CHECK(first_logical_dim % 128 == 0, - "First dimension of a grouped tensor should be divisible by 128."); + "First logical dimension of a grouped tensor must be divisible by 128."); } const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); From 39794e2dd47c9cdb74fb0b8d9512b886cc023364 Mon Sep 17 00:00:00 2001 From: Oleg Goncharov Date: Fri, 13 Feb 
2026 14:28:33 +0000 Subject: [PATCH 5/5] Added notes on group tensor restrictions into documentation Signed-off-by: Oleg Goncharov --- .../include/transformer_engine/activation.h | 10 ++++++++ .../common/include/transformer_engine/cast.h | 25 ++++++++++++------- 2 files changed, 26 insertions(+), 9 deletions(-) diff --git a/transformer_engine/common/include/transformer_engine/activation.h b/transformer_engine/common/include/transformer_engine/activation.h index 4c9eed3365..482ff64ccb 100644 --- a/transformer_engine/common/include/transformer_engine/activation.h +++ b/transformer_engine/common/include/transformer_engine/activation.h @@ -55,6 +55,7 @@ void nvte_gelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -75,6 +76,7 @@ void nvte_silu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the SiLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -95,6 +97,7 @@ void nvte_relu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the ReLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -115,6 +118,7 @@ void nvte_qgelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Quick GeLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -135,6 +139,7 @@ void nvte_srelu(const NVTETensor input, NVTETensor output, cudaStream_t stream); /*! \brief Computes the Squared ReLU activation of the grouped input. * If the scaling mode of the grouped output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor for activation. * \param[in,out] output Output grouped tensor. @@ -157,6 +162,7 @@ void nvte_dgelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the GeLU activation gradient of the grouped input. 
* If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -181,6 +187,7 @@ void nvte_dsilu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the SiLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -205,6 +212,7 @@ void nvte_drelu(const NVTETensor grad, const NVTETensor input, NVTETensor output /*! \brief Computes the ReLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -229,6 +237,7 @@ void nvte_dqgelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Quick GeLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. @@ -253,6 +262,7 @@ void nvte_dsrelu(const NVTETensor grad, const NVTETensor input, NVTETensor outpu /*! \brief Computes the Squared ReLU activation gradient of the grouped input. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] grad Incoming grouped gradient. * \param[in] input Input grouped tensor for activation. diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h index 95d01fd8bf..755052d6dd 100644 --- a/transformer_engine/common/include/transformer_engine/cast.h +++ b/transformer_engine/common/include/transformer_engine/cast.h @@ -92,6 +92,7 @@ void nvte_quantize(const NVTETensor input, NVTETensor output, cudaStream_t strea /*! \brief Casts input grouped tensor to MXFP8. * The type of quantized tensor in the output depends on the scaling mode of the output * tensor. See file level comments. + * For grouped tensors with a varying last dimension, the last dimension must be a multiple of 128. * * \param[in] input Input grouped tensor to be cast. * \param[in,out] output Output grouped MXFP8 tensor. @@ -146,6 +147,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d /*! \brief Casts input grouped tensor to MXFP8. Additionally, reduces the input along columns. 
* If the scaling mode of the output tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -190,6 +192,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -237,6 +240,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the SiLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -284,6 +288,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu * Additionally, reduces the result of the ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -331,6 +336,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Quick GeLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -378,6 +384,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp * Additionally, reduces the result of the Squared ReLU backward along columns. * If the scaling mode of the output grouped tensor is set to NVTE_MXFP8_1D_SCALING, * the block quantization (MXFP8) of the specified shape of the block will be used. + * Grouped dbias is not yet supported for grouped tensors with a varying last dimension. * * This function produces 2 results: * - `output` is equal to `cast(dact(input))` @@ -412,11 +419,11 @@ void nvte_dequantize(const NVTETensor input, NVTETensor output, cudaStream_t str /*! \brief Casts multiple input tensors to quantized output tensors. * - * \param[in] inputs List of input tensors to be cast. - * \param[in,out] outputs List of output quantized tensors. - * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] num_tensors Number of input and output tensors. - * \param[in] stream CUDA stream used for the operation. + * \param[in] inputs List of input tensors to be cast. + * \param[in,out] outputs List of output quantized tensors. 
+ * \param[in] quant_config (Optional) Quantization configurations. + * \param[in] num_tensors Number of input and output tensors. + * \param[in] stream CUDA stream used for the operation. */ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, const NVTEQuantizationConfig quant_config, const size_t num_tensors, @@ -425,11 +432,11 @@ void nvte_multi_tensor_quantize(const NVTETensor *inputs, NVTETensor *outputs, /*! \brief Casts grouped input tensor to quantized output tensors. * * \param[in] input Input tensor to be cast. - * \param[in,out] outputs Output quantized tensors. - * \param[in] split_sections Split sections of the input tensor. - * \param[in] num_tensors Number of output tensors. + * \param[in,out] outputs Output quantized tensors. + * \param[in] split_sections Split sections of the input tensor. + * \param[in] num_tensors Number of output tensors. * \param[in] quant_config (Optional) Quantization configurations. - * \param[in] stream CUDA stream used for the operation. + * \param[in] stream CUDA stream used for the operation. */ void nvte_group_nvfp4_quantize_with_amax(const NVTETensor input, NVTETensor *outputs, const size_t *split_sections, size_t num_tensors,
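
A minimal host-side sketch of the per-tensor reduction that the new group_reduce_dbias_kernel / grouped_reduce_dbias path in this series computes, assuming the SAME_BOTH_DIMS layout (every tensor in the group has the same shape). The function name grouped_reduce_dbias_reference, its parameters, and the buffer layout comments are illustrative assumptions for explanation only, not part of the Transformer Engine API.

    // dbias_partial: [first_logical_dim / chunk_dim_Y][last_logical_dim] fp32 partial
    //                column sums written by the quantization kernel into the workspace.
    // dbias_out:     [num_tensors][last_logical_dim] final per-tensor column sums,
    //                cast to the input dtype (OType).
    #include <cstddef>
    #include <vector>

    template <typename OType>
    void grouped_reduce_dbias_reference(const std::vector<float> &dbias_partial,
                                        std::vector<OType> &dbias_out,
                                        size_t num_tensors,
                                        size_t first_logical_dim,
                                        size_t last_logical_dim,
                                        size_t chunk_dim_Y) {
      // SAME_BOTH_DIMS: every tensor contributes the same number of rows, so its
      // partial rows form a contiguous slab of the workspace.
      const size_t tensor_rows = first_logical_dim / num_tensors;
      const size_t partial_rows_per_tensor = tensor_rows / chunk_dim_Y;
      const size_t cols = last_logical_dim;

      for (size_t t = 0; t < num_tensors; ++t) {
        const size_t row_begin = t * partial_rows_per_tensor;
        for (size_t j = 0; j < cols; ++j) {
          float acc = 0.0f;
          // Sum the partial rows belonging to tensor t for column j.
          for (size_t r = 0; r < partial_rows_per_tensor; ++r) {
            acc += dbias_partial[(row_begin + r) * cols + j];
          }
          dbias_out[t * cols + j] = static_cast<OType>(acc);
        }
      }
    }

For VARYING_FIRST_DIM groups, the kernel instead reads the per-tensor row count from first_dims and derives the slab start as tensor_offsets[t] / last_logical_dim / chunk_dim_Y, which only works when all tensors share the last dimension; this is why the patch restricts grouped dbias to groups with a constant last dimension.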