244 changes: 135 additions & 109 deletions transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
@@ -21,6 +21,7 @@
#include "../../util/ptx.cuh"
#include "../../utils.cuh"
#include "../core/common.cuh"
#include "swizzle.cuh"

namespace transformer_engine {
namespace dispatch {
@@ -231,7 +232,7 @@ __device__ __forceinline__ void fence_acquire_tensormap(const CUtensorMap *tenso

template <bool IS_DBIAS, bool IS_DACT, bool IS_ACT, typename ParamOP,
float (*OP)(float, const ParamOP &), typename IType, typename OType, bool ROWWISE_SCALING,
bool COLWISE_SCALING>
bool COLWISE_SCALING, bool WITH_GEMM_SWIZZLED_SCALES>
__global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel(
const __grid_constant__ CUtensorMap tensor_map_input_static,
const __grid_constant__ CUtensorMap tensor_map_act_input_static,
@@ -250,6 +251,8 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
using IType2 = typename ptx::FPx2<IType>;
using OType2 = typename ptx::FPx2<OType>;

using transformer_engine::dispatch::mxfp8::swizzle::gemm_swizzled_scale_idx;

if constexpr (NO_ACTIVATIONS) {
if (noop != nullptr && noop[0] == 1.0f) {
return;
@@ -475,8 +478,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel

const size_t global_scales_offset_Y = scales_offset_Y_colwise + stage;
const size_t global_scales_offset_X = scales_offset_X_colwise;
const size_t scale_idx =
global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;
Comment on lines +482 to 489 (Contributor):
Type inconsistency: scale_idx is declared as size_t in the colwise path but as int in the rowwise path (line 615). Should use consistent type (size_t) in both paths.

Suggested change
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
scales_colwise[scale_idx] = biased_exponent;
size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(global_scales_offset_X, global_scales_offset_Y,
DIVUP(rows, static_cast<size_t>(128)));
} else {
scale_idx = global_scales_offset_Y * scale_stride_colwise + global_scales_offset_X;
}
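
A minimal host-side sketch of why the index should be computed in size_t in both paths (hypothetical sizes, not values taken from this PR): for large grouped tensors the product offset_Y * scale_stride can exceed INT32_MAX, so computing it in a 32-bit int would be signed overflow.

#include <cstddef>
#include <cstdint>
#include <cstdio>

int main() {
  // Hypothetical flattened scale matrix: 1 << 17 rows, 1 << 15 scales per row.
  const std::size_t scale_stride_rowwise = std::size_t{1} << 15;
  const std::size_t offset_Y = (std::size_t{1} << 17) - 1;  // last row
  const std::size_t offset_X = 100;

  // Fits comfortably in size_t; would not fit in a 32-bit int.
  const std::size_t idx = offset_Y * scale_stride_rowwise + offset_X;
  std::printf("size_t index: %zu (INT32_MAX = %d)\n", idx, INT32_MAX);
  return 0;
}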


const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -602,7 +611,14 @@ __global__ void __launch_bounds__(THREADS_PER_CHUNK) group_quantize_mxfp8_kernel
ptx::float_to_e8m0(thread_amax * Quantized_Limits<OType>::max_norm_rcp);
const int stage_scales_offset_Y = scales_offset_Y_rowwise + stage_offset_Y;
const int stage_scales_offset_X = scales_offset_X_rowwise;
const int scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;

size_t scale_idx = 0;
if constexpr (WITH_GEMM_SWIZZLED_SCALES) {
scale_idx = gemm_swizzled_scale_idx(stage_scales_offset_Y, stage_scales_offset_X,
DIVUP(cols, static_cast<size_t>(128)));
} else {
scale_idx = stage_scales_offset_Y * scale_stride_rowwise + stage_scales_offset_X;
}
scales_rowwise[scale_idx] = biased_exponent;

const float block_scale_inverse = ptx::exp2f_rcp(biased_exponent);
@@ -803,6 +819,8 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
const dim3 grid(blocks);
const size_t block_size = THREADS_PER_CHUNK;

const bool with_gemm_swizzled_scales = output->with_gemm_swizzled_scales;

// Logical shape of a tensor with varying all dims is [1, M*K]
if (shape_rep != ShapeRepresentation::VARYING_BOTH_DIMS) {
NVTE_CHECK(first_logical_dim % 128 == 0,
@@ -848,111 +866,119 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations
input->dtype(), IType,
TRANSFORMER_ENGINE_TYPE_SWITCH_FP8ONLY(
output->dtype(), OType,

alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};

constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim, last_logical_dim,
BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0, input_type_bit_size);

if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}

if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}

if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}

constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;

const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;

const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

auto kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false>;
break;
}
case ScalingType::COLWISE: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel = group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true>;
break;
}
}

// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);

const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;

OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;

OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim, last_logical_dim,
offsets_ptr, first_dims_ptr, last_dims_ptr, use_rowwise_scaling,
use_colwise_scaling, IS_DACT);
}

NVTE_CHECK_CUDA(cudaFuncSetAttribute(kernel, cudaFuncAttributeMaxDynamicSharedMemorySize,
dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}

NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
TRANSFORMER_ENGINE_SWITCH_CONDITION(
with_gemm_swizzled_scales, WITH_GEMM_SWIZZLED_SCALES,

alignas(64) CUtensorMap tensor_map_input{};
alignas(64) CUtensorMap tensor_map_act_input{};
alignas(64) CUtensorMap tensor_map_output_rowwise{};
alignas(64) CUtensorMap tensor_map_output_colwise{};

constexpr size_t input_type_bit_size = TypeInfo<IType>::size;
constexpr size_t output_type_bit_size = TypeInfo<OType>::size;

create_2D_tensor_map(tensor_map_input, input->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);

if constexpr (IS_DACT) {
create_2D_tensor_map(tensor_map_act_input, activations->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
input_type_bit_size);
}

if (use_rowwise_scaling) {
create_2D_tensor_map(tensor_map_output_rowwise, output->data, first_logical_dim,
last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X, last_logical_dim, 0,
output_type_bit_size);
}

if (use_colwise_scaling) {
create_2D_tensor_map(tensor_map_output_colwise, output->columnwise_data,
first_logical_dim, last_logical_dim, BUFF_DIM_Y, BUFF_DIM_X,
last_logical_dim, 0, output_type_bit_size);
}

constexpr size_t buff_elems = BUFF_DIM_Y * BUFF_DIM_X;
constexpr size_t buff_elems_total = BUFFS_NUM * buff_elems;
constexpr size_t input_buff_size = (buff_elems_total * input_type_bit_size) / 8;
constexpr size_t output_buff_size = (buff_elems_total * output_type_bit_size) / 8;
constexpr size_t buff_size_aligned_in =
DIVUP_TO_MULTIPLE(input_buff_size, TMA_SHMEM_ALIGNMENT);
constexpr size_t buff_size_aligned_out =
DIVUP_TO_MULTIPLE(output_buff_size, TMA_SHMEM_ALIGNMENT);

constexpr size_t elt_input_mem = buff_size_aligned_in;
constexpr size_t act_input_mem = (IS_DACT ? buff_size_aligned_in : 0);
constexpr size_t in_mem = elt_input_mem + act_input_mem;

const size_t out_rowwise_mem = (use_rowwise_scaling ? buff_size_aligned_out : 0);
const size_t out_colwise_mem = (use_colwise_scaling ? buff_size_aligned_out : 0);
const size_t out_mem = out_rowwise_mem + out_colwise_mem;

const size_t dshmem_size = in_mem + out_mem + TMA_SHMEM_ALIGNMENT;

auto kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType, OType,
true, true, WITH_GEMM_SWIZZLED_SCALES>;
switch (scaling_type) {
case ScalingType::ROWWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, false, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::COLWISE: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, false, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
case ScalingType::BIDIMENSIONAL: {
kernel =
group_quantize_mxfp8_kernel<IS_DBIAS, IS_DACT, IS_ACT, ParamOP, OP, IType,
OType, true, true, WITH_GEMM_SWIZZLED_SCALES>;
break;
}
}

// Update tensor descriptors before launching the kernel
if (!is_single_tensor) {
const IType *const input_dptr = reinterpret_cast<const IType *>(input->data.dptr);

const IType *const act_input_dptr =
IS_DACT ? reinterpret_cast<const IType *>(activations->data.dptr) : nullptr;

OType *const output_rowwise_dptr =
use_rowwise_scaling ? reinterpret_cast<OType *>(output->data.dptr) : nullptr;

OType *const output_colwise_dptr =
use_colwise_scaling ? reinterpret_cast<OType *>(output->columnwise_data.dptr)
: nullptr;
update_tma_descriptors<IType, OType><<<num_tensors, 32, 0, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, input_dptr, act_input_dptr, output_rowwise_dptr,
output_colwise_dptr, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr,
use_rowwise_scaling, use_colwise_scaling, IS_DACT);
}

NVTE_CHECK_CUDA(cudaFuncSetAttribute(
kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, dshmem_size));

kernel<<<grid, block_size, dshmem_size, stream>>>(
tensor_map_input, tensor_map_act_input, tensor_map_output_rowwise,
tensor_map_output_colwise, shape_rep, num_tensors, first_logical_dim,
last_logical_dim, offsets_ptr, first_dims_ptr, last_dims_ptr, scales_rowwise_ptr,
scales_colwise_ptr, noop_ptr, workspace_ptr, amax_ptr);

if constexpr (IS_DBIAS) {
common::reduce_dbias<IType>(workspace_ptr, dbias, dbias_rows, dbias_cols, stream);
}

NVTE_CHECK_CUDA(cudaGetLastError());); // NOLINT(*)
); // NOLINT(*)
); // NOLINT(*)
}

} // namespace mxfp8
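
The host-side change above wraps the kernel selection in TRANSFORMER_ENGINE_SWITCH_CONDITION so that the runtime flag with_gemm_swizzled_scales becomes the compile-time template parameter WITH_GEMM_SWIZZLED_SCALES. A minimal generic sketch of that dispatch pattern (an equivalent idiom, not the macro's actual definition; switch_condition and launch_kernel are placeholder names):

#include <cstdio>
#include <type_traits>
#include <utility>

// Toy stand-in for a kernel templated on the compile-time flag.
template <bool WITH_GEMM_SWIZZLED_SCALES>
void launch_kernel() {
  std::printf("GEMM-swizzled scales: %s\n", WITH_GEMM_SWIZZLED_SCALES ? "on" : "off");
}

// Generic runtime-bool -> compile-time-bool dispatcher.
template <typename Body>
void switch_condition(bool flag, Body &&body) {
  if (flag) {
    std::forward<Body>(body)(std::true_type{});
  } else {
    std::forward<Body>(body)(std::false_type{});
  }
}

int main() {
  const bool with_gemm_swizzled_scales = true;  // e.g. output->with_gemm_swizzled_scales
  switch_condition(with_gemm_swizzled_scales, [](auto flag) {
    launch_kernel<decltype(flag)::value>();
  });
  return 0;
}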
7 changes: 7 additions & 0 deletions transformer_engine/common/common.h
@@ -335,6 +335,12 @@ struct GroupedTensor {

NVTEGroupedTensor nvte_tensor;

/*! \brief Whether scaling factors are in the format expected by GEMM
*
* Only meaningful for MXFP8 and NVFP4.
*/
bool with_gemm_swizzled_scales = false;

GroupedTensor(NVTEScalingMode scaling_mode, size_t num_tensors)
: data(),
columnwise_data(),
@@ -401,6 +407,7 @@
num_tensors = 0;
scaling_mode = NVTE_DELAYED_TENSOR_SCALING;
nvte_tensor = 0;
with_gemm_swizzled_scales = false;
}
};
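
For reference, the "format expected by GEMM" is the swizzled block-scale layout consumed by Blackwell block-scaled GEMMs, commonly documented as 128x4 tiles of 512 scales, with the scale for logical (row r, column c) at [r/128][c/4][r%32][(r%128)/32][c%4]. A minimal sketch of that index arithmetic (illustrative only; the gemm_swizzled_scale_idx helper in swizzle.cuh may differ in argument order and padding details):

#include <cstddef>
#include <cstdio>

// r, c: logical position in the unswizzled scale matrix; cols_padded: number of
// scale columns after padding to a multiple of 4 (rows are padded to a multiple of 128).
std::size_t swizzled_scale_offset(std::size_t r, std::size_t c, std::size_t cols_padded) {
  const std::size_t tile_base =
      (r / 128) * (cols_padded / 4) * 512 + (c / 4) * 512;  // 512 scales per 128x4 tile
  const std::size_t in_tile = (r % 32) * 16 + ((r % 128) / 32) * 4 + (c % 4);
  return tile_base + in_tile;
}

int main() {
  // Scale matrix padded to 256 x 8 (e.g. 256 data rows, 256 data columns / 32).
  std::printf("%zu\n", swizzled_scale_offset(0, 0, 8));   // 0
  std::printf("%zu\n", swizzled_scale_offset(1, 0, 8));   // 16
  std::printf("%zu\n", swizzled_scale_offset(32, 0, 8));  // 4
  std::printf("%zu\n", swizzled_scale_offset(0, 4, 8));   // 512: next 128x4 tile
  return 0;
}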
