diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
index 7801a2064d..a29a09836e 100644
--- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
+++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
@@ -21,6 +21,7 @@
 #include "../../util/ptx.cuh"
 #include "../../utils.cuh"
 #include "../core/common.cuh"
+#include "swizzle.cuh"

 namespace transformer_engine {
 namespace dispatch {
@@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso

 template <...
-          bool COLWISE_SCALING>
+          bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
 __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
     const __grid_constant__ CUtensorMap tensor_map_input_static,
     const __grid_constant__ CUtensorMap tensor_map_act_input_static,
@@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
   using IType2 = typename ptx::FPx2<IType>;
   using OType2 = typename ptx::FPx2<OType>;

+  using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;
+
   if constexpr (NO_ACTIVATIONS) {
     if (noop != nullptr && noop[0] == 1.0f) {
       return;
@@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
           const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
           const size_t global_scales_offset_X = scales_offset_X_colwise;
-          const size_t scale_idx =
-              global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+
+          size_t scale_idx = 0;
+          if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+            scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
+                                                DIVUP(rows, static_cast<size_t>(128)));
+          } else {
+            scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
+          }
           scales_colwise[scale_idx] = biased_exponent;

           const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
           ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
       const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
       const int stage_scales_offset_X = scales_offset_X_rowwise;
-      const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+
+      size_t scale_idx = 0;
+      if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
+        scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
+                                            DIVUP(cols, static_cast<size_t>(128)));
+      } else {
+        scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
+      }
       scales_rowwise[scale_idx] = biased_exponent;

       const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
   const dim3 grid(blocks);
   const size_t block_size = THREADS_PER_CHUNK;

+  const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;
+
   // Logical shape of a tensor with varying all dims is [1, M*K]
   if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
     NVTE_CHECK(first_logical_dim % 128 == 0,
@@ -848,111 +866,119 @@
       input->dtype(), IType,
       TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
           output->dtype(), OType,
-
-          alignas(64) CUtensorMap tensor_map_input{};
-          alignas(64) CUtensorMap tensor_map_act_input{};
-          alignas(64) CUtensorMap tensor_map_output_rowwise{};
-          alignas(64) CUtensorMap tensor_map_output_colwise{};
-
-          constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
-          constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
-
-          create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
-                               BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size);
-
-          if constexpr (IS_DACT) {
-            create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
-                                 last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
-                                 input_type_bit_size);
-          }
-
-          if (use_rowwise_scaling) {
-            create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
-                                 last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
-                                 output_type_bit_size);
-          }
-
-          if (use_colwise_scaling) {
-            create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
-                                 first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
-                                 last_logical_dim, 0, output_type_bit_size);
-          }
-
-          constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
-          constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
-          constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
-          constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
-          constexpr size_t buff_size_aligned_in =
-              DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
-          constexpr size_t buff_size_aligned_out =
-              DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
-
-          constexpr size_t elt_input_mem = buff_size_aligned_in;
-          constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
-          constexpr size_t in_mem = elt_input_mem + act_input_mem;
-
-          const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
-          const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
-          const size_t out_mem = out_rowwise_mem + out_colwise_mem;
-
-          const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
-
-          auto kernel = group_quantize_mxfp8_kernel<...>;
-          switch (scaling_type) {
-            case ScalingType::ROWWISE: {
-              kernel = group_quantize_mxfp8_kernel<...>;
-              break;
-            }
-            case ScalingType::COLWISE: {
-              kernel = group_quantize_mxfp8_kernel<...>;
-              break;
-            }
-            case ScalingType::BIDIMENSIONAL: {
-              kernel = group_quantize_mxfp8_kernel<...>;
-              break;
-            }
-          }
-
-          // Update tensor descriptors before launching the kernel
-          if (!is_single_tensor) {
-            const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
-
-            const IType *const act_input_dptr =
-                IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
-
-            OType *const output_rowwise_dptr =
-                use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
-
-            OType *const output_colwise_dptr =
-                use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
-                                    : nullptr;
-            update_tma_descriptors<<<...>>>(
-                tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-                tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
-                output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim,
-                offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling,
-                use_colwise_scaling, IS_DACT);
-          }
-
-          NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
-                                               dshmem_size));
-
-          kernel<<<grid, block_size, dshmem_size, stream>>>(
-              tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
-              tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
-              last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
-              scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
-
-          if constexpr (IS_DBIAS) {
-            common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
-          }
-
-          NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
-  ); // NOLINT(*)
+          TRANSFORMER_ENGINE_SWITCH_CONDITION(
+              with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,
+
+              alignas(64) CUtensorMap tensor_map_input{};
+              alignas(64) CUtensorMap tensor_map_act_input{};
+              alignas(64) CUtensorMap tensor_map_output_rowwise{};
+              alignas(64) CUtensorMap tensor_map_output_colwise{};
+
+              constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
+              constexpr size_t output_type_bit_size = TypeInfo<OType>::size;
+
+              create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
+                                   last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                                   input_type_bit_size);
+
+              if constexpr (IS_DACT) {
+                create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
+                                     last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                                     input_type_bit_size);
+              }
+
+              if (use_rowwise_scaling) {
+                create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
+                                     last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
+                                     output_type_bit_size);
+              }
+
+              if (use_colwise_scaling) {
+                create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
+                                     first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
+                                     last_logical_dim, 0, output_type_bit_size);
+              }
+
+              constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
+              constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
+              constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
+              constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
+              constexpr size_t buff_size_aligned_in =
+                  DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
+              constexpr size_t buff_size_aligned_out =
+                  DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);
+
+              constexpr size_t elt_input_mem = buff_size_aligned_in;
+              constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
+              constexpr size_t in_mem = elt_input_mem + act_input_mem;
+
+              const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
+              const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
+              const size_t out_mem = out_rowwise_mem + out_colwise_mem;
+
+              const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;
+
+              auto kernel =
+                  group_quantize_mxfp8_kernel<..., WITH_GEMM_SWIZZLED_SCALES>;
+              switch (scaling_type) {
+                case ScalingType::ROWWISE: {
+                  kernel =
+                      group_quantize_mxfp8_kernel<..., WITH_GEMM_SWIZZLED_SCALES>;
+                  break;
+                }
+                case ScalingType::COLWISE: {
+                  kernel =
+                      group_quantize_mxfp8_kernel<..., WITH_GEMM_SWIZZLED_SCALES>;
+                  break;
+                }
+                case ScalingType::BIDIMENSIONAL: {
+                  kernel =
+                      group_quantize_mxfp8_kernel<..., WITH_GEMM_SWIZZLED_SCALES>;
+                  break;
+                }
+              }
+
+              // Update tensor descriptors before launching the kernel
+              if (!is_single_tensor) {
+                const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);
+
+                const IType *const act_input_dptr =
+                    IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;
+
+                OType *const output_rowwise_dptr =
+                    use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;
+
+                OType *const output_colwise_dptr =
+                    use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
+                                        : nullptr;
+                update_tma_descriptors<<<...>>>(
+                    tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+                    tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
+                    output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
+                    last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
+                    use_rowwise_scaling, use_colwise_scaling, IS_DACT);
+              }
+
+              NVTE_CHECK_CUDA(cudaFuncSetAttribute(
+                  kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));
+
+              kernel<<<grid, block_size, dshmem_size, stream>>>(
+                  tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
+                  tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
+                  last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
+                  scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);
+
+              if constexpr (IS_DBIAS) {
+                common::reduce_dbias(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
+              }
+
+              NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
+      ); // NOLINT(*)
+  ); // NOLINT(*)
 }

 }  // namespace mxfp8
diff --git a/transformer_engine/common/common.h b/transformer_engine/common/common.h
index 99a2985d5e..2d7f0e7e8c 100644
--- a/transformer_engine/common/common.h
+++ b/transformer_engine/common/common.h
@@ -335,6 +335,12 @@ struct GroupedTensor {

   NVTEGroupedTensor nvte_tensor;

+  /*! \brief Whether scaling factors are stored in the swizzled format expected by GEMM
+   *
+   *  Only meaningful for MXFP8 and NVFP4.
+   */
+  bool with_gemm_swizzled_scales = false;
+
   GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
       : data(),
         columnwise_data(),
@@ -401,6 +407,7 @@ struct GroupedTensor {
     num_tensors = 0;
     scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
     nvte_tensor = 0;
+    with_gemm_swizzled_scales = false;
   }
 };
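
Note on gemm_swizzled_scale_idx: the helper comes from the newly included swizzle.cuh, which is not part of this diff, so its exact signature and internals are assumptions here. Judging from the flag name and the call sites above, the kernel writes each E8M0 scale factor directly at the offset that the GEMM expects for MXFP8 block scaling factors: scales grouped into 128-row by 4-column tiles of 512 entries, with a 32-row interleave inside each tile, and the third argument giving the number of 4-wide tiles per 128-row band. A minimal sketch of such an index mapping, assuming that layout (the names below are illustrative, not the library's API):

// Illustrative sketch only -- the real helper lives in swizzle.cuh (not shown in this diff).
// Maps a (row, col) coordinate of the logical scale-factor matrix to an offset in the
// swizzled buffer, assuming the 128x4-tile layout used for MXFP8 block scales.
#include <cstddef>

inline size_t gemm_swizzled_scale_idx_sketch(size_t row, size_t col, size_t num_col_tiles) {
  const size_t tile_row = row / 128;  // which 128-row band of the scale matrix
  const size_t tile_col = col / 4;    // which 4-column tile inside that band
  const size_t r = row % 128;         // position inside the 128x4 tile
  const size_t c = col % 4;
  // Tiles hold 512 scales each and are laid out band by band; inside a tile the
  // 128 rows are interleaved in four groups of 32.
  return (tile_row * num_col_tiles + tile_col) * 512 + (r % 32) * 16 + (r / 32) * 4 + c;
}

Under this assumption the third argument passed in the kernel (DIVUP(cols, 128) for row-wise scales, DIVUP(rows, 128) for column-wise scales) is the tiles-per-band count, and the column-wise call swaps its X/Y offsets because those scales describe the transposed operand.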