Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
114 changes: 56 additions & 58 deletions tests/cpp/operator/test_cast_mxfp8_grouped.cu
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,7 @@ void compute_ref(const ProcessingMethod processing_method,
const size_t rows,
const size_t cols,
const size_t scales_stride_rowwise,
const size_t scales_stride_colwise,
const bool is_single_tensor)
const size_t scales_stride_colwise)
{
const size_t tile_size_Y = 32;
const size_t tile_size_X = 32;
Expand Down Expand Up @@ -169,10 +168,8 @@ void compute_ref(const ProcessingMethod processing_method,
}
}

if (is_single_tensor) {
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
for (size_t j = 0; j < cols; ++j) {
output_dbias[j] = static_cast<InputType>(output_dbias_fp32[j]);
}
}

Expand Down Expand Up @@ -250,19 +247,24 @@ void performTest(const ProcessingMethod processing_method,
DType itype = TypeInfo<InputType>::dtype;
DType otype = TypeInfo<OutputType>::dtype;

const bool compute_dbias = (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT);

const size_t rows = logical_shape_vec[0];
const size_t cols = logical_shape_vec[1];

size_t elts_num = 0;
size_t rowwise_sfs_num = 0;
size_t colwise_sfs_num = 0;
size_t sum_of_last_dims = 0;

std::vector<size_t> rowwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> rowwise_scales_offset(num_tensors + 1, 0);
std::vector<size_t> colwise_scales_first_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_last_dim(num_tensors, 0);
std::vector<size_t> colwise_scales_offset(num_tensors + 1, 0);
std::vector<size_t> dbias_offsets(num_tensors + 1, 0);

for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
Expand All @@ -285,13 +287,13 @@ void performTest(const ProcessingMethod processing_method,

rowwise_sfs_num += rowwise_sfs;
colwise_sfs_num += colwise_sfs;
sum_of_last_dims += K;

rowwise_scales_offset[t+1] = rowwise_sfs_num;
colwise_scales_offset[t+1] = colwise_sfs_num;
dbias_offsets[t+1] = sum_of_last_dims;
}

const bool is_single_tensor = (shape_rep == SAME_BOTH_DIMS) || (shape_rep == VARYING_FIRST_DIM);

std::vector<size_t> scales_rowwise_shape = {rowwise_sfs_num};
std::vector<size_t> scales_colwise_shape = {colwise_sfs_num};

Expand All @@ -311,7 +313,7 @@ void performTest(const ProcessingMethod processing_method,
std::vector<fp8e8m0> out_scales_rowwise_ref(rowwise ? rowwise_sfs_num : 0);
std::vector<fp8e8m0> out_scales_colwise_ref(colwise ? colwise_sfs_num : 0);

std::vector<InputType> ref_output_dbias(is_single_tensor ? cols : 0);
std::vector<InputType> ref_output_dbias(sum_of_last_dims, static_cast<InputType>(0.0f));

for (size_t i = 0; i < elts_num; ++i) {
const float val = dis(gen);
Expand All @@ -336,6 +338,7 @@ void performTest(const ProcessingMethod processing_method,

const size_t in_data_size = elts_num * sizeof(InputType);
const size_t out_data_size = elts_num * sizeof(OutputType);
const size_t dbias_data_size = sum_of_last_dims * sizeof(InputType);
const size_t rowwise_scales_size = rowwise_sfs_num * sizeof(fp8e8m0);
const size_t colwise_scales_size = colwise_sfs_num * sizeof(fp8e8m0);

Expand All @@ -345,6 +348,7 @@ void performTest(const ProcessingMethod processing_method,

InputType* grad_data_d;
InputType* in_data_d;
InputType* dbias_out_data_d;
OutputType* out_data_rowwise_d;
OutputType* out_data_colwise_d;
fp8e8m0* out_scales_rowwise_d;
Expand All @@ -367,6 +371,10 @@ void performTest(const ProcessingMethod processing_method,

NVTEShape logical_shape_ = nvte_make_shape(logical_shape_vec.data(), logical_shape_vec.size());

std::vector<size_t> dbias_logical_shape_vec= {num_tensors, cols};
NVTEShape dbias_logical_shape_ = nvte_make_shape(dbias_logical_shape_vec.data(),
dbias_logical_shape_vec.size());

NVTEShape first_dims_shape_;
NVTEShape last_dims_shape_;
NVTEShape offsets_shape_;
Expand All @@ -382,11 +390,12 @@ void performTest(const ProcessingMethod processing_method,
NVTEGroupedTensor grad_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor in_group_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor out_group_tensor = nvte_create_grouped_tensor(NVTE_MXFP8_1D_SCALING, num_tensors, logical_shape_);
NVTEGroupedTensor output_dbias_tensor = nvte_create_grouped_tensor(NVTE_DELAYED_TENSOR_SCALING, num_tensors, dbias_logical_shape_);

NVTEBasicTensor grad_data_tensor = {grad_data_d, static_cast<NVTEDType>(itype), logical_shape_};
NVTEBasicTensor in_data_tensor = {in_data_d, static_cast<NVTEDType>(itype), logical_shape_};
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);
nvte_set_grouped_tensor_param(&grad_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &grad_data_tensor);
nvte_set_grouped_tensor_param(&in_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &in_data_tensor);

if ((shape_rep == VARYING_FIRST_DIM) || (shape_rep == VARYING_BOTH_DIMS)) {
NVTEBasicTensor first_dims_tensor = {first_dims_d, kNVTEInt64, first_dims_shape_};
Expand Down Expand Up @@ -433,52 +442,40 @@ void performTest(const ProcessingMethod processing_method,
nvte_set_grouped_tensor_param(&out_group_tensor, NVTEGroupedTensorParam::kNVTEGroupedColumnwiseScaleInv, &out_scales_colwise_tensor);
}

Tensor output_dbias("output_dbias", std::vector<size_t>{ cols }, itype);
if (compute_dbias) {
cudaMalloc((void**)&dbias_out_data_d, dbias_data_size);
cudaMemset(dbias_out_data_d, 0, dbias_data_size);
NVTEBasicTensor output_dbias_data_tensor = {dbias_out_data_d, static_cast<NVTEDType>(itype), dbias_logical_shape_};
nvte_set_grouped_tensor_param(&output_dbias_tensor, NVTEGroupedTensorParam::kNVTEGroupedRowwiseData, &output_dbias_data_tensor);
}

// Reference (CPU)
if (is_single_tensor) {

const size_t unpadded_rowwise_blocks_X = divide_round_up(cols, 32);
const size_t unpadded_colwise_blocks_X = cols;
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];

const size_t scales_stride_rowwise = round_up_to_nearest_multiple(unpadded_rowwise_blocks_X, 4);
const size_t scales_stride_colwise = round_up_to_nearest_multiple(unpadded_colwise_blocks_X, 128);
const size_t scales_stride_rowwise = rowwise_scales_last_dim[t];
const size_t scales_stride_colwise = colwise_scales_last_dim[t];
const size_t data_offset = offsets_h[t];
const size_t rowwise_sfs_offset = rowwise_scales_offset[t];
const size_t colwise_sfs_offset = colwise_scales_offset[t];
const size_t dbias_offset = dbias_offsets[t];

const InputType* const grad_ptr = grad_data.data() + data_offset;
const InputType* const in_ptr = in_data.data() + data_offset;
OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset;
OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset;
fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset;
fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset;
InputType* const ref_output_dbias_ptr = ref_output_dbias.data() + dbias_offset;

compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_data.data(), grad_data.data(),
out_data_rowwise_ref.data(), out_data_colwise_ref.data(),
out_scales_rowwise_ref.data(), out_scales_colwise_ref.data(),
ref_output_dbias.data(), rows, cols,
processing_method, OP, rowwise, colwise, in_ptr, grad_ptr,
out_data_rowwise_ptr, out_data_colwise_ptr,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias_ptr, M, K,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
} else {
for (size_t t = 0; t < num_tensors; ++t) {
const size_t M = first_dims_h[t];
const size_t K = last_dims_h[t];

const size_t scales_stride_rowwise = rowwise_scales_last_dim[t];
const size_t scales_stride_colwise = colwise_scales_last_dim[t];
const size_t data_offset = offsets_h[t];
const size_t rowwise_sfs_offset = rowwise_scales_offset[t];
const size_t colwise_sfs_offset = colwise_scales_offset[t];

const InputType* const grad_ptr = grad_data.data() + data_offset;
const InputType* const in_ptr = in_data.data() + data_offset;
OutputType* const out_data_rowwise_ptr = out_data_rowwise_ref.data() + data_offset;
OutputType* const out_data_colwise_ptr = out_data_colwise_ref.data() + data_offset;
fp8e8m0* const out_scales_rowwise_ptr = out_scales_rowwise_ref.data() + rowwise_sfs_offset;
fp8e8m0* const out_scales_colwise_ptr = out_scales_colwise_ref.data() + colwise_sfs_offset;

compute_ref<InputType, OutputType>(
processing_method, OP, rowwise, colwise, in_ptr, grad_ptr,
out_data_rowwise_ptr, out_data_colwise_ptr,
out_scales_rowwise_ptr, out_scales_colwise_ptr,
ref_output_dbias.data(), M, K,
scales_stride_rowwise,
scales_stride_colwise,
is_single_tensor);
}
scales_stride_colwise);
}

// GPU
Expand All @@ -489,9 +486,9 @@ void performTest(const ProcessingMethod processing_method,
break;
}
case ProcessingMethod::CAST_DBIAS: {
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias.data(), workspace.data(), 0);
nvte_group_quantize_dbias(grad_group_tensor, out_group_tensor, output_dbias_tensor, workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_DBIAS_DACT: {
Expand All @@ -502,10 +499,10 @@ void performTest(const ProcessingMethod processing_method,
else if (OP == &dsrelu) { nvte_group_quantize_dbias_dact = &nvte_group_quantize_dbias_dsrelu; }

nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
output_dbias_tensor, workspace.data(), 0);
workspace = Tensor("workspace", workspace.rowwise_shape(), workspace.dtype());
nvte_group_quantize_dbias_dact(grad_group_tensor, in_group_tensor, out_group_tensor,
output_dbias.data(), workspace.data(), 0);
output_dbias_tensor, workspace.data(), 0);
break;
}
case ProcessingMethod::CAST_ACT: {
Expand Down Expand Up @@ -566,9 +563,10 @@ void performTest(const ProcessingMethod processing_method,
out_data_colwise_h.data(), rows, cols, false, mismatches_elts);
}

if (processing_method == ProcessingMethod::CAST_DBIAS
|| processing_method == ProcessingMethod::CAST_DBIAS_DACT)
{
if (compute_dbias) {
Tensor output_dbias("output_dbias", std::vector<size_t>{ sum_of_last_dims }, itype);
cudaMemcpy(output_dbias.rowwise_dptr(), dbias_out_data_d, dbias_data_size, cudaMemcpyDeviceToDevice);

auto [atol_dbias, rtol_dbias] = getTolerances(itype);
if (itype == DType::kFloat32) {
atol_dbias = 1e-4;
Expand All @@ -581,6 +579,7 @@ void performTest(const ProcessingMethod processing_method,

cudaFree(grad_data_d);
cudaFree(in_data_d);
cudaFree(dbias_out_data_d);
cudaFree(first_dims_d);
cudaFree(last_dims_d);
cudaFree(offsets_d);
Expand Down Expand Up @@ -628,7 +627,6 @@ std::vector<std::vector<size_t>> input_config = {
{SAME_BOTH_DIMS, 1, 128,128},
{SAME_BOTH_DIMS, 2, 256,128},
{VARYING_FIRST_DIM, 2, 512,128, 128,384},
{VARYING_FIRST_DIM, 2, 384,160, 128,256},
{VARYING_FIRST_DIM, 5, 4096,512, 128,256,384,1024,2304},
{VARYING_LAST_DIM, 3, 256,896, 128,256,512},
{VARYING_BOTH_DIMS, 2, 1,(128*128)+(256*256), 128,256, 128,256},
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/common/activation/gelu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void nvte_group_dgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
Expand All @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dgelu(const NVTETensor input, const NVTETensor activati

void nvte_group_quantize_dbias_dgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dgelu);
using namespace transformer_engine;
Expand Down Expand Up @@ -110,7 +110,7 @@ void nvte_group_dqgelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dqgelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
Expand All @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dqgelu(const NVTETensor input, const NVTETensor activat

void nvte_group_quantize_dbias_dqgelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dqgelu);
using namespace transformer_engine;
Expand Down
8 changes: 4 additions & 4 deletions transformer_engine/common/activation/relu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void nvte_group_drelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_drelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
Expand All @@ -57,7 +57,7 @@ void nvte_quantize_dbias_drelu(const NVTETensor input, const NVTETensor activati

void nvte_group_quantize_dbias_drelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_drelu);
using namespace transformer_engine;
Expand Down Expand Up @@ -110,7 +110,7 @@ void nvte_group_dsrelu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inp
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsrelu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
Expand All @@ -135,7 +135,7 @@ void nvte_quantize_dbias_dsrelu(const NVTETensor input, const NVTETensor activat

void nvte_group_quantize_dbias_dsrelu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsrelu);
using namespace transformer_engine;
Expand Down
4 changes: 2 additions & 2 deletions transformer_engine/common/activation/swiglu.cu
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ void nvte_group_dsilu(const NVTEGroupedTensor grad, const NVTEGroupedTensor inpu
NVTEGroupedTensor output, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_dsilu);
using namespace transformer_engine;
NVTETensor dbias = nullptr;
NVTEGroupedTensor dbias = nullptr;
NVTETensor workspace = nullptr;

constexpr bool IS_DBIAS = false;
Expand All @@ -57,7 +57,7 @@ void nvte_quantize_dbias_dsilu(const NVTETensor input, const NVTETensor activati

void nvte_group_quantize_dbias_dsilu(const NVTEGroupedTensor input,
const NVTEGroupedTensor activation_input,
NVTEGroupedTensor output, NVTETensor dbias,
NVTEGroupedTensor output, NVTEGroupedTensor dbias,
NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias_dsilu);
using namespace transformer_engine;
Expand Down
2 changes: 1 addition & 1 deletion transformer_engine/common/cast/cast.cu
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ void nvte_quantize_dbias(const NVTETensor input, NVTETensor output, NVTETensor d
}

void nvte_group_quantize_dbias(const NVTEGroupedTensor input, NVTEGroupedTensor output,
NVTETensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTEGroupedTensor dbias, NVTETensor workspace, cudaStream_t stream) {
NVTE_API_CALL(nvte_group_quantize_dbias);
using namespace transformer_engine;

Expand Down
Loading
Loading