diff --git a/tests/cpp/operator/test_cast_mxfp8_grouped.cu b/tests/cpp/operator/test_cast_mxfp8_grouped.cu index 8b084ca452..889de78a52 100644 --- a/tests/cpp/operator/test_cast_mxfp8_grouped.cu +++ b/tests/cpp/operator/test_cast_mxfp8_grouped.cu @@ -58,8 +58,7 @@ void compute_ref(const ProcessingMethod processing_method, const size_t rows, const size_t cols, const size_t scales_stride_rowwise, - const size_t scales_stride_colwise, - const bool is_single_tensor) + const size_t scales_stride_colwise) { const size_t tile_size_Y = 32; const size_t tile_size_X = 32; @@ -169,10 +168,8 @@ void compute_ref(const ProcessingMethod processing_method, } } - if (is_single_tensor) { - for (size_t j = 0; j < cols; ++j) { - output_dbias[j] = static_cast(output_dbias_fp32[j]); - } + for (size_t j = 0; j < cols; ++j) { + output_dbias[j] = static_cast(output_dbias_fp32[j]); } } @@ -250,12 +247,16 @@ void performTest(const ProcessingMethod processing_method, DType itype = TypeInfo::dtype; DType otype = TypeInfo::dtype; + const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS + || processing_method == ProcessingMethod::CAST_DBIAS_DACT); + const size_t rows = logical_shape_vec[0]; const size_t cols = logical_shape_vec[1]; size_t elts_num = 0; size_t rowwise_sfs_num = 0; size_t colwise_sfs_num = 0; + size_t sum_of_last_dims = 0; std::vector rowwise_scales_first_dim(num_tensors, 0); std::vector rowwise_scales_last_dim(num_tensors, 0); @@ -263,6 +264,7 @@ void performTest(const ProcessingMethod processing_method, std::vector colwise_scales_first_dim(num_tensors, 0); std::vector colwise_scales_last_dim(num_tensors, 0); std::vector colwise_scales_offset(num_tensors + 1, 0); + std::vector dbias_offsets(num_tensors + 1, 0); for (size_t t = 0; t < num_tensors; ++t) { const size_t M = first_dims_h[t]; @@ -285,13 +287,13 @@ void performTest(const ProcessingMethod processing_method, rowwise_sfs_num += rowwise_sfs; colwise_sfs_num += colwise_sfs; + sum_of_last_dims += K; rowwise_scales_offset[t+1] = rowwise_sfs_num; colwise_scales_offset[t+1] = colwise_sfs_num; + dbias_offsets[t+1] = sum_of_last_dims; } - const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM); - std::vector scales_rowwise_shape = {rowwise_sfs_num}; std::vector scales_colwise_shape = {colwise_sfs_num}; @@ -311,7 +313,7 @@ void performTest(const ProcessingMethod processing_method, std::vector out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0); std::vector out_scales_colwise_ref(colwise ? colwise_sfs_num : 0); - std::vector ref_output_dbias(is_single_tensor ? 
cols : 0); + std::vector ref_output_dbias(sum_of_last_dims, static_cast(0.0f)); for (size_t i = 0; i < elts_num; ++i) { const float val = dis(gen); @@ -336,6 +338,7 @@ void performTest(const ProcessingMethod processing_method, const size_t in_data_size = elts_num * sizeof(InputType); const size_t out_data_size = elts_num * sizeof(OutputType); + const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType); const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0); const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0); @@ -345,6 +348,7 @@ void performTest(const ProcessingMethod processing_method, InputType* grad_data_d; InputType* in_data_d; + InputType* dbias_out_data_d; OutputType* out_data_rowwise_d; OutputType* out_data_colwise_d; fp8e8m0* out_scales_rowwise_d; @@ -367,6 +371,10 @@ void performTest(const ProcessingMethod processing_method, NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size()); + std::vector dbias_logical_shape_vec= {num_tensors, cols}; + NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(), + dbias_logical_shape_vec.size()); + NVTEShape first_dims_shape_; NVTEShape last_dims_shape_; NVTEShape offsets_shape_; @@ -382,11 +390,12 @@ void performTest(const ProcessingMethod processing_method, NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_); NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_); + NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_); NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast(itype), logical_shape_}; NVTEBasicTensor in_data_tensor = {in_data_d, static_cast(itype), logical_shape_}; - nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor); + nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor); if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) { NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_}; @@ -433,52 +442,40 @@ void performTest(const ProcessingMethod processing_method, nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor); } - Tensor output_dbias("output_dbias", std::vector{ cols }, itype); + if (compute_dbias) { + cudaMalloc((void**)&dbias_out_data_d, dbias_data_size); + cudaMemset(dbias_out_data_d, 0, dbias_data_size); + NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast(itype), dbias_logical_shape_}; + nvte_set_grouped_tensor_param(&output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor); + } // Reference (CPU) - if (is_single_tensor) { - - const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32); - const size_t unpadded_colwise_blocks_X = cols; + for (size_t t = 0; t < num_tensors; ++t) { + const size_t M = first_dims_h[t]; + const size_t K = last_dims_h[t]; - const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4); - const 
size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128); + const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; + const size_t scales_stride_colwise = colwise_scales_last_dim[t]; + const size_t data_offset = offsets_h[t]; + const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; + const size_t colwise_sfs_offset = colwise_scales_offset[t]; + const size_t dbias_offset = dbias_offsets[t]; + + const InputType* const grad_ptr = grad_data.data() + data_offset; + const InputType* const in_ptr = in_data.data() + data_offset; + OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; + OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; + fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; + fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; + InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset; compute_ref( - processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(), - out_data_rowwise_ref.data(), out_data_colwise_ref.data(), - out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(), - ref_output_dbias.data(), rows, cols, + processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, + out_data_rowwise_ptr, out_data_colwise_ptr, + out_scales_rowwise_ptr, out_scales_colwise_ptr, + ref_output_dbias_ptr, M, K, scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } else { - for (size_t t = 0; t < num_tensors; ++t) { - const size_t M = first_dims_h[t]; - const size_t K = last_dims_h[t]; - - const size_t scales_stride_rowwise = rowwise_scales_last_dim[t]; - const size_t scales_stride_colwise = colwise_scales_last_dim[t]; - const size_t data_offset = offsets_h[t]; - const size_t rowwise_sfs_offset = rowwise_scales_offset[t]; - const size_t colwise_sfs_offset = colwise_scales_offset[t]; - - const InputType* const grad_ptr = grad_data.data() + data_offset; - const InputType* const in_ptr = in_data.data() + data_offset; - OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset; - OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset; - fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset; - fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset; - - compute_ref( - processing_method, OP, rowwise, colwise, in_ptr, grad_ptr, - out_data_rowwise_ptr, out_data_colwise_ptr, - out_scales_rowwise_ptr, out_scales_colwise_ptr, - ref_output_dbias.data(), M, K, - scales_stride_rowwise, - scales_stride_colwise, - is_single_tensor); - } + scales_stride_colwise); } // GPU @@ -489,9 +486,9 @@ void performTest(const ProcessingMethod processing_method, break; } case ProcessingMethod::CAST_DBIAS: { - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); - nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0); + nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_DBIAS_DACT: { @@ -502,10 +499,10 @@ void performTest(const ProcessingMethod processing_method, else if (OP == 
&dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; } nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype()); nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor, - output_dbias.data(), workspace.data(), 0); + output_dbias_tensor, workspace.data(), 0); break; } case ProcessingMethod::CAST_ACT: { @@ -566,9 +563,10 @@ void performTest(const ProcessingMethod processing_method, out_data_colwise_h.data(), rows, cols, false, mismatches_elts); } - if (processing_method == ProcessingMethod::CAST_DBIAS - || processing_method == ProcessingMethod::CAST_DBIAS_DACT) - { + if (compute_dbias) { + Tensor output_dbias("output_dbias", std::vector{ sum_of_last_dims }, itype); + cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice); + auto [atol_dbias, rtol_dbias] = getTolerances(itype); if (itype == DType::kFloat32) { atol_dbias = 1e-4; @@ -581,6 +579,7 @@ void performTest(const ProcessingMethod processing_method, cudaFree(grad_data_d); cudaFree(in_data_d); + cudaFree(dbias_out_data_d); cudaFree(first_dims_d); cudaFree(last_dims_d); cudaFree(offsets_d); @@ -628,7 +627,6 @@ std::vector> input_config = { {SAME_BOTH_DIMS, 1, 128,128}, {SAME_BOTH_DIMS, 2, 256,128}, {VARYING_FIRST_DIM, 2, 512,128, 128,384}, - {VARYING_FIRST_DIM, 2, 384,160, 128,256}, {VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304}, {VARYING_LAST_DIM, 3, 256,896, 128,256,512}, {VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256}, diff --git a/transformer_engine/common/activation/gelu.cu b/transformer_engine/common/activation/gelu.cu index d209ea8d47..ea864813bf 100644 --- a/transformer_engine/common/activation/gelu.cu +++ b/transformer_engine/common/activation/gelu.cu @@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dgelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dqgelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu); using namespace transformer_engine; diff --git 
a/transformer_engine/common/activation/relu.cu b/transformer_engine/common/activation/relu.cu index b6f758caf6..fc9122b7ec 100644 --- a/transformer_engine/common/activation/relu.cu +++ b/transformer_engine/common/activation/relu.cu @@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_drelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_drelu); using namespace transformer_engine; @@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsrelu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu); using namespace transformer_engine; diff --git a/transformer_engine/common/activation/swiglu.cu b/transformer_engine/common/activation/swiglu.cu index 77d5b6867f..12478af4cf 100644 --- a/transformer_engine/common/activation/swiglu.cu +++ b/transformer_engine/common/activation/swiglu.cu @@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu NVTEGroupedTensor output, cudaStream_t stream) { NVTE_API_CALL(nvte_group_dsilu); using namespace transformer_engine; - NVTETensor dbias = nullptr; + NVTEGroupedTensor dbias = nullptr; NVTETensor workspace = nullptr; constexpr bool IS_DBIAS = false; @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input, const NVTEGroupedTensor activation_input, - NVTEGroupedTensor output, NVTETensor dbias, + NVTEGroupedTensor output, NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias_dsilu); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/cast.cu b/transformer_engine/common/cast/cast.cu index 582172a88e..f00b34b3d9 100644 --- a/transformer_engine/common/cast/cast.cu +++ b/transformer_engine/common/cast/cast.cu @@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d } void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) { + NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) { NVTE_API_CALL(nvte_group_quantize_dbias); using namespace transformer_engine; diff --git a/transformer_engine/common/cast/core/common.cuh 
b/transformer_engine/common/cast/core/common.cuh
index 0997b01f7e..a4e033939b 100644
--- a/transformer_engine/common/cast/core/common.cuh
+++ b/transformer_engine/common/cast/core/common.cuh
@@ -22,6 +22,14 @@ namespace transformer_engine {
 namespace dispatch {
 namespace common {
+
+enum ShapeRepresentation {
+  SAME_BOTH_DIMS = 0,
+  VARYING_FIRST_DIM = 1,
+  VARYING_LAST_DIM = 2,
+  VARYING_BOTH_DIMS = 3
+};
+
 inline bool full_tile_1D_tensor(const Tensor *const t, const size_t elems_per_block) {
   const size_t N = product(t->data.shape);
   const bool isFullTile = (N % elems_per_block == 0);

@@ -78,6 +86,56 @@ __global__ void __launch_bounds__(THREADS_PER_BLOCK)
   }
   stg_vec.store_to(thread_out_base);
 }
+
+template <int nvec, typename OType>
+__global__ void __launch_bounds__(THREADS_PER_BLOCK)
+    group_reduce_dbias_kernel(const ShapeRepresentation shape_rep, const size_t num_tensors,
+                              const size_t first_logical_dim, const size_t last_logical_dim,
+                              const int64_t *const offsets_ptr, const int64_t *const first_dims_ptr,
+                              const int64_t *const last_dims_ptr, OType *const dbias_output,
+                              const float *dbias_partial, const size_t chunk_dim_Y) {
+  using ComputeVec = Vec<float, nvec>;
+  using OutputVec = Vec<OType, nvec>;
+
+  const size_t tensor_id = blockIdx.y;
+  const size_t tensor_rows = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
+                                 ? (first_logical_dim / num_tensors)
+                                 : first_dims_ptr[tensor_id];
+
+  const size_t rows = tensor_rows / chunk_dim_Y;
+  const size_t cols = last_logical_dim;
+
+  const size_t dbias_in_offset_Y = (shape_rep == ShapeRepresentation::SAME_BOTH_DIMS)
+                                       ? (tensor_id * (tensor_rows / chunk_dim_Y))
+                                       : (offsets_ptr[tensor_id] / cols / chunk_dim_Y);
+
+  const size_t thread_id = blockIdx.x * blockDim.x + threadIdx.x;
+
+  if (thread_id * nvec >= cols) {
+    return;
+  }
+
+  const float *const thread_in_base = dbias_partial + dbias_in_offset_Y * cols + thread_id * nvec;
+  OType *const thread_out_base = dbias_output + tensor_id * cols + thread_id * nvec;
+
+  ComputeVec ldg_vec;
+  ComputeVec acc_vec;
+  acc_vec.clear();
+  for (int i = 0; i < rows; ++i) {
+    ldg_vec.load_from(thread_in_base + i * cols);
+#pragma unroll
+    for (int e = 0; e < nvec; ++e) {
+      acc_vec.data.elt[e] += ldg_vec.data.elt[e];
+    }
+  }
+
+  OutputVec stg_vec;
+#pragma unroll
+  for (int e = 0; e < nvec; ++e) {
+    stg_vec.data.elt[e] = static_cast<OType>(acc_vec.data.elt[e]);
+  }
+  stg_vec.store_to(thread_out_base);
+}
 }  // namespace kernel

 template <typename IType>
@@ -96,6 +154,32 @@ void reduce_dbias(const float *workspace_ptr, Tensor *dbias, const size_t rows,
   NVTE_CHECK_CUDA(cudaGetLastError());
 }

+template <typename IType>
+void grouped_reduce_dbias(const ShapeRepresentation shape_rep, const size_t num_tensors,
+                          const size_t first_logical_dim, const size_t last_logical_dim,
+                          const int64_t *const data_tensor_offsets_ptr,
+                          const int64_t *const data_tensor_first_dims_ptr,
+                          const int64_t *const data_tensor_last_dims_ptr, GroupedTensor *dbias,
+                          const float *workspace_ptr, const size_t chunk_dim_Y,
+                          cudaStream_t stream) {
+  using namespace kernel;
+  constexpr size_t reduce_dbias_store_bytes = 8;  // stg.64
+  constexpr size_t reduce_dbias_nvec = reduce_dbias_store_bytes / sizeof(IType);
+
+  NVTE_CHECK(last_logical_dim % reduce_dbias_nvec == 0, "Unsupported shape.");
+
+  const size_t blocks_X = DIVUP(last_logical_dim, THREADS_PER_BLOCK * reduce_dbias_nvec);
+  const size_t blocks_Y = num_tensors;
+  const dim3 grid(blocks_X, blocks_Y);
+
+  group_reduce_dbias_kernel<reduce_dbias_nvec, IType><<<grid, THREADS_PER_BLOCK, 0, stream>>>(
+      shape_rep, num_tensors, first_logical_dim, last_logical_dim, data_tensor_offsets_ptr,
+      data_tensor_first_dims_ptr, data_tensor_last_dims_ptr,
+      reinterpret_cast<IType *>(dbias->data.dptr), workspace_ptr, chunk_dim_Y);
+
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
 }  // namespace common
 }  // namespace dispatch
 }  // namespace transformer_engine
diff --git a/transformer_engine/common/cast/dispatch/quantize.cuh b/transformer_engine/common/cast/dispatch/quantize.cuh
index b83df1dedf..1bb6552349 100644
--- a/transformer_engine/common/cast/dispatch/quantize.cuh
+++ b/transformer_engine/common/cast/dispatch/quantize.cuh
@@ -380,13 +380,13 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor
   NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);

   const NVTEGroupedTensor activation = nullptr;
-  NVTETensor dbias = nullptr;
+  NVTEGroupedTensor dbias = nullptr;
   NVTETensor workspace = nullptr;

   const GroupedTensor *input_tensor = convertNVTEGroupedTensorCheck(input);
   GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
   const GroupedTensor *activations_tensor = convertNVTEGroupedTensor(activation);
-  Tensor *dbias_tensor = convertNVTETensor(dbias);
+  GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias);
   Tensor *workspace_tensor = convertNVTETensor(workspace);

   // Quantization config
@@ -417,8 +417,9 @@ void group_quantize_fwd_helper(const NVTEGroupedTensor input, NVTEGroupedTensor

 template
 void group_quantize_bwd_helper(const NVTEGroupedTensor grad, const NVTEGroupedTensor input,
-                               NVTEGroupedTensor output, NVTETensor dbias, NVTETensor workspace,
-                               const NVTEQuantizationConfig quant_config, cudaStream_t stream) {
+                               NVTEGroupedTensor output, NVTEGroupedTensor dbias,
+                               NVTETensor workspace, const NVTEQuantizationConfig quant_config,
+                               cudaStream_t stream) {
   using namespace detail;

   NVTEScalingMode scaling_mode = nvte_grouped_tensor_scaling_mode(output);
@@ -426,7 +427,7 @@ void group_quantize_bwd_helper(const NVTEGroupedTe
   const GroupedTensor *grad_tensor = convertNVTEGroupedTensorCheck(grad);
   const GroupedTensor *input_tensor = convertNVTEGroupedTensor(input);
   GroupedTensor *output_tensor = convertNVTEGroupedTensorCheck(output);
-  Tensor *dbias_tensor = convertNVTETensor(dbias);
+  GroupedTensor *dbias_tensor = convertNVTEGroupedTensor(dbias);
   Tensor *workspace_tensor = convertNVTETensor(workspace);

   // Quantization config
diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
index a29a09836e..6e314bfdf0 100644
--- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
+++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
@@ -28,19 +28,14 @@ namespace dispatch {
 namespace mxfp8 {
 namespace group_quantize_kernel {

+using namespace dispatch::common;
+
 constexpr int MAX_SUPPORTED_TENSOR_DESCRIPTORS = 64;
 __device__ alignas(128) CUtensorMap g_tensor_maps_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
 __device__ alignas(128) CUtensorMap g_tensor_maps_act_input[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
 __device__ alignas(128) CUtensorMap g_tensor_maps_output_rowwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];
 __device__ alignas(128) CUtensorMap g_tensor_maps_output_colwise[MAX_SUPPORTED_TENSOR_DESCRIPTORS];

-enum ShapeRepresentation {
-  SAME_BOTH_DIMS = 0,
-  VARYING_FIRST_DIM = 1,
-  VARYING_LAST_DIM = 2,
-  VARYING_BOTH_DIMS = 3
-};
-
 constexpr size_t SCALE_DIM_Y = 32;
 constexpr size_t SCALE_DIM_X = 32;

@@ -144,11 +139,15 @@ __device__ __forceinline__ void modify_base_tensor_map(const CUtensorMap base_te
   if constexpr (is_blackwell) {
     const size_t global_stride_bytes = global_dim_X * data_type_size_bytes;
     if (global_stride_bytes % TMA_GMEM_ALIGNMENT != 0) {
-      NVTE_DEVICE_ERROR("Shape not supported, as data stride must be 16B aligned.");
+      NVTE_DEVICE_ERROR("Shape not supported. Data stride must be 16B aligned.");
     }
     if (global_data_ptr % TMA_GMEM_ALIGNMENT != 0) {
       NVTE_DEVICE_ERROR("Tensor data pointer must be 16B aligned");
     }
+    if (global_dim_X % CHUNK_DIM_X != 0) {
+      NVTE_DEVICE_ERROR(
+          "The grouped tensor must be divisible by 128x128 tiles without a tail tile.");
+    }

     asm volatile(
         "{\n\t"
@@ -749,8 +748,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel

 template
 void group_quantize(const GroupedTensor *input, const GroupedTensor *activations,
-                    const Tensor *noop, GroupedTensor *output, Tensor *dbias, Tensor *workspace,
-                    cudaStream_t stream) {
+                    const Tensor *noop, GroupedTensor *output, GroupedTensor *dbias,
+                    Tensor *workspace, cudaStream_t stream) {
   using namespace group_quantize_kernel;

   checkCuDriverContext(stream);
@@ -801,22 +800,14 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations

   const size_t num_tensors = input->num_tensors;

-  size_t blocks = 0;
-
-  if (is_single_tensor) {
-    const size_t blocks_Y = DIVUP(first_logical_dim, CHUNK_DIM_Y);
-    const size_t blocks_X = DIVUP(last_logical_dim, CHUNK_DIM_X);
-    blocks = blocks_Y * blocks_X;
-  } else {
-    NVTE_CHECK(num_tensors < MAX_SUPPORTED_TENSOR_DESCRIPTORS,
+  if (!is_single_tensor) {
+    NVTE_CHECK(num_tensors <= MAX_SUPPORTED_TENSOR_DESCRIPTORS,
               "Number of tensors in a group is larger than "
               "the MAX number of supported descriptors (64).");
-    // Only full tiles supported
-    NVTE_CHECK(last_logical_dim % CHUNK_DIM_X == 0,
-               "Last dimension of a grouped tensor should be divisible by 128.");
-    blocks = DIVUP(elts_total, CHUNK_DIM_Y * CHUNK_DIM_X);
   }
-  const dim3 grid(blocks);
+
+  NVTE_CHECK(elts_total % ELTS_PER_CHUNK == 0, "Only full-tile grouped tensors supported.");
+  const dim3 grid(elts_total / ELTS_PER_CHUNK);
   const size_t block_size = THREADS_PER_CHUNK;

   const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
@@ -845,18 +836,20 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
     NVTE_CHECK(scales_colwise_ptr != nullptr, "Columnwise scaling tensor must be allocated");
   }

-  const size_t dbias_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y);
-  const size_t dbias_cols = last_logical_dim;
   if constexpr (IS_DBIAS) {
     NVTE_CHECK(is_single_tensor,
                "DBias is only supported for tensors with the const last dimension.");
     NVTE_CHECK(dbias->data.dtype == input->dtype(),
                "DBias must have the same type as input_tensor.");
-    NVTE_CHECK(dbias->data.shape == std::vector<size_t>{last_logical_dim}, "Wrong shape of DBias.");
-    NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
+    std::vector<size_t> expected_shape_dbias_tensor = {num_tensors, last_logical_dim};
+    NVTE_CHECK(dbias->data.shape == expected_shape_dbias_tensor, "Wrong shape of DBias.");
+
+    NVTE_CHECK(workspace != nullptr, "Workspace must be a tensor.");
+    const size_t dbias_workspace_rows = DIVUP(first_logical_dim, CHUNK_DIM_Y);
+    const size_t dbias_workspace_cols = last_logical_dim;
     if (workspace->data.dptr == nullptr) {
-      workspace->data.shape = {dbias_rows, dbias_cols};
+      workspace->data.shape = {dbias_workspace_rows, dbias_workspace_cols};
       workspace->data.dtype = DType::kFloat32;
       return;
     }
@@ -973,7 +966,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
             scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

         if constexpr (IS_DBIAS) {
-          common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
+          common::grouped_reduce_dbias<IType>(
+              shape_rep, num_tensors, first_logical_dim, last_logical_dim, offsets_ptr,
+              first_dims_ptr, last_dims_ptr, dbias, workspace_ptr, CHUNK_DIM_Y, stream);
         }

         NVTE_CHECK_CUDA(cudaGetLastError()););  // NOLINT(*)
diff --git a/transformer_engine/common/include/transformer_engine/cast.h b/transformer_engine/common/include/transformer_engine/cast.h
index 04712d3003..95d01fd8bf 100644
--- a/transformer_engine/common/include/transformer_engine/cast.h
+++ b/transformer_engine/common/include/transformer_engine/cast.h
@@ -161,7 +161,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
  *  \param[in]     stream          CUDA stream used for the operation.
  */
 void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
-                               NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                               NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream);

 /*! \brief Computes backward of GeLU operation on the input, then casts to FP8/MXFP8.
  *         Additionally, reduces the result of the GeLU backward along columns.
@@ -207,7 +207,8 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor act_inpu
  */
 void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
-                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                                     NVTEGroupedTensor dbias, NVTETensor workspace,
+                                     cudaStream_t stream);

 /*! \brief Computes backward of SiLU operation on the input, then casts to FP8/MXFP8.
  *         Additionally, reduces the result of the SiLU backward along columns.
@@ -253,7 +254,8 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor act_inpu
  */
 void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
-                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                                     NVTEGroupedTensor dbias, NVTETensor workspace,
+                                     cudaStream_t stream);

 /*! \brief Computes backward of ReLU operation on the input, then casts to FP8/MXFP8.
  *         Additionally, reduces the result of the ReLU backward along columns.
@@ -299,7 +301,8 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor act_inpu
  */
 void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
                                      const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
-                                     NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                                     NVTEGroupedTensor dbias, NVTETensor workspace,
+                                     cudaStream_t stream);

 /*! \brief Computes backward of Quick GeLU operation on the input, then casts to FP8/MXFP8.
  *         Additionally, reduces the result of the Quick GeLU backward along columns.
@@ -345,7 +348,8 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor act_inp
  */
 void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
                                       const NVTEGroupedTensor act_input, NVTEGroupedTensor output,
-                                      NVTETensor dbias, NVTETensor workspace, cudaStream_t stream);
+                                      NVTEGroupedTensor dbias, NVTETensor workspace,
+                                      cudaStream_t stream);

 /*! \brief Computes backward of Squared ReLU operation on the input, then casts to FP8/MXFP8.
  *         Additionally, reduces the result of the Squared ReLU backward along columns.
@@ -391,7 +395,8 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor act_inp */ void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input, const NVTEGroupedTensor act_input, NVTEGroupedTensor output, - NVTETensor dbias, NVTETensor workspace, cudaStream_t stream); + NVTEGroupedTensor dbias, NVTETensor workspace, + cudaStream_t stream); /*! \brief Casts input tensor from reduced to higher precision. * If the scaling mode of the input tensor is set to NVTE_MXFP8_1D_SCALING,
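
Reviewer note: below is a minimal caller-side sketch of how the dbias output is wired up now that `nvte_group_quantize_dbias` takes an `NVTEGroupedTensor`, mirroring the updated test. The `transformer_engine.h` include location, the helper name `setup_and_run_grouped_dbias`, and the caller-allocated device buffer `dbias_data_d` are assumptions for illustration, not part of this patch.

```cpp
// Sketch only: the grouped dbias output now has logical shape {num_tensors, cols}
// (one bias row per member tensor) and is passed as an NVTEGroupedTensor.
#include <cuda_runtime.h>
#include <transformer_engine/cast.h>
#include <transformer_engine/transformer_engine.h>  // assumed location of the grouped-tensor API

void setup_and_run_grouped_dbias(NVTEGroupedTensor grad_group_tensor,
                                 NVTEGroupedTensor out_group_tensor, NVTETensor workspace,
                                 size_t num_tensors, size_t cols, NVTEDType itype,
                                 void *dbias_data_d, cudaStream_t stream) {
  // Logical shape of the grouped dbias: one row of length `cols` per member tensor.
  const size_t dbias_shape[] = {num_tensors, cols};
  const NVTEShape dbias_logical_shape = nvte_make_shape(dbias_shape, 2);

  NVTEGroupedTensor dbias =
      nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape);

  // Attach the caller-allocated device buffer as the row-wise data of the grouped tensor.
  NVTEBasicTensor dbias_data = {dbias_data_d, itype, dbias_logical_shape};
  nvte_set_grouped_tensor_param(&dbias, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData,
                                &dbias_data);

  // As in the test, the first call only reports the required workspace shape/dtype; the
  // caller reallocates the workspace accordingly and calls again to do the actual work.
  nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, dbias, workspace, stream);
}
```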
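The new `group_reduce_dbias_kernel` folds the float partial-dbias workspace (one row per `chunk_dim_Y`-row chunk of the packed input) into one bias row per member tensor. Below is a plain host-side reference of that reduction for the varying-first-dimension case, with illustrative names (`reference_grouped_dbias` is not part of the patch); for `SAME_BOTH_DIMS` the kernel derives each tensor's row count as `first_logical_dim / num_tensors` instead of reading `first_dims`.

```cpp
// Host-side reference for the reduction performed by group_reduce_dbias_kernel.
// The quantization kernel leaves partial column sums in a float workspace of shape
// [first_logical_dim / chunk_dim_Y, cols]; this folds the chunk rows belonging to
// tensor t into dbias row t. Illustrative sketch, not library code.
#include <cstddef>
#include <cstdint>
#include <vector>

std::vector<float> reference_grouped_dbias(const std::vector<float> &workspace,   // partial sums
                                           const std::vector<int64_t> &offsets,   // element offsets
                                           const std::vector<int64_t> &first_dims,
                                           size_t num_tensors, size_t cols, size_t chunk_dim_Y) {
  std::vector<float> dbias(num_tensors * cols, 0.0f);
  for (size_t t = 0; t < num_tensors; ++t) {
    // First workspace row of tensor t: the same arithmetic the kernel uses for
    // dbias_in_offset_Y (offsets are in elements, hence the division by cols).
    const size_t ws_row_begin = static_cast<size_t>(offsets[t]) / cols / chunk_dim_Y;
    const size_t ws_rows = static_cast<size_t>(first_dims[t]) / chunk_dim_Y;
    for (size_t r = 0; r < ws_rows; ++r) {
      for (size_t c = 0; c < cols; ++c) {
        dbias[t * cols + c] += workspace[(ws_row_begin + r) * cols + c];
      }
    }
  }
  return dbias;
}
```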
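The dropped `input_config` entry `{VARYING_FIRST_DIM, 2, 384,160, 128,256}` is presumably a casualty of the stricter full-tile requirement introduced here (`elts_total % ELTS_PER_CHUNK == 0` on the host, `global_dim_X % CHUNK_DIM_X == 0` on the device): a 160-wide last dimension cannot be covered by 128x128 chunks. A small validity check capturing that rule, assuming `CHUNK_DIM_X == CHUNK_DIM_Y == 128` as the new error message suggests; `grouped_shape_supported` is a hypothetical helper, not code from this patch.

```cpp
// Checks whether a grouped tensor's logical shape satisfies the full-tile requirement.
#include <cstddef>

bool grouped_shape_supported(size_t first_logical_dim, size_t last_logical_dim) {
  constexpr size_t CHUNK_DIM_Y = 128;  // assumed, per the "128x128 tiles" error message
  constexpr size_t CHUNK_DIM_X = 128;
  constexpr size_t ELTS_PER_CHUNK = CHUNK_DIM_Y * CHUNK_DIM_X;
  return (last_logical_dim % CHUNK_DIM_X == 0) &&
         ((first_logical_dim * last_logical_dim) % ELTS_PER_CHUNK == 0);
}

// grouped_shape_supported(384, 160) == false: 160 is not a multiple of 128,
// which matches the removal of that configuration from the test list.
```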