20 changes: 11 additions & 9 deletions tests/cpp/operator/test_grouped_gemm.cu
@@ -88,16 +88,16 @@ struct TestParams {
std::vector<std::tuple<size_t, size_t, size_t>> make_shapes(ShapeCase scase) {
switch (scase) {
case ShapeCase::kAllSame:
return {{64, 64, 32}, {64, 64, 32}, {64, 64, 32}};
return {{128, 256, 384}, {128, 256, 384}, {128, 256, 384}};
case ShapeCase::kSameFirst:
// Same M (first dim), varying N and K
return {{64, 80, 32}, {64, 96, 48}, {64, 112, 64}};
return {{128, 256, 384}, {128, 384, 512}, {128, 512, 640}};
case ShapeCase::kSameLast:
// Same N (last dim), varying M and K
return {{64, 80, 32}, {80, 80, 48}, {96, 80, 64}};
return {{128, 256, 384}, {256, 256, 512}, {384, 256, 640}};
case ShapeCase::kAllDifferent:
default:
return {{64, 96, 32}, {80, 112, 48}, {96, 128, 64}};
return {{128, 256, 384}, {256, 384, 512}, {384, 512, 640}};
}
}

@@ -123,10 +123,11 @@ void run_grouped_gemm_case(const TestParams& params) {

for (size_t i = 0; i < num_gemms; ++i) {
const auto [M, N, K] = shapes[i];
const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{M, K}
: std::vector<size_t>{K, M};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, N}
: std::vector<size_t>{N, K};

const std::vector<size_t> a_shape = params.transa ? std::vector<size_t>{N, K}
: std::vector<size_t>{K, N};
const std::vector<size_t> b_shape = params.transb ? std::vector<size_t>{K, M}
: std::vector<size_t>{M, K};
switch (params.input_case) {
case InputCase::kFP8Current: {
A_tensors.emplace_back(make_fp8_operand("A" + std::to_string(i), a_shape));
@@ -247,6 +248,8 @@ void run_grouped_gemm_case(const TestParams& params) {
nullptr, // config (use defaults)
0);

NVTE_CHECK_CUDA(cudaDeviceSynchronize());
// Compare results
for (size_t i = 0; i < num_gemms; ++i) {
Tensor grouped_split("grouped_D" + std::to_string(i),
std::vector<size_t>{static_cast<size_t>(std::get<0>(shapes[i])),
Expand Down Expand Up @@ -289,7 +292,6 @@ std::string MakeGroupedGemmTestName(const testing::TestParamInfo<GroupedGemmTest
// TestParams: {input_case, transa, transb, shape_case, use_null_c}
const std::vector<TestParams> kTestParams = {
// Basic tests
{InputCase::kFP8Current, true, false, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, true, ShapeCase::kAllDifferent, false},
{InputCase::kFP8Current, false, false, ShapeCase::kAllSame, false},
{InputCase::kBF16, true, false, ShapeCase::kSameFirst, false},
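The updated shapes encode the row-major storage convention for the grouped GEMM operands: reading each tuple as (M, N, K), as the kSameFirst/kSameLast comments indicate, A is stored as [N, K] when transa is set and [K, N] otherwise, B as [K, M] when transb is set and [M, K] otherwise, and each per-GEMM output D as [M, N]. A minimal standalone sketch of that mapping (the OperandShapes struct and make_operand_shapes helper are illustrative only, not part of the test):

#include <array>
#include <cassert>
#include <cstddef>

// Row-major storage shapes for one GEMM in the group, given logical dims (M, N, K).
struct OperandShapes {
  std::array<std::size_t, 2> a, b, d;
};

OperandShapes make_operand_shapes(std::size_t M, std::size_t N, std::size_t K,
                                  bool transa, bool transb) {
  OperandShapes s;
  s.a = transa ? std::array<std::size_t, 2>{N, K} : std::array<std::size_t, 2>{K, N};
  s.b = transb ? std::array<std::size_t, 2>{K, M} : std::array<std::size_t, 2>{M, K};
  s.d = {M, N};
  return s;
}

int main() {
  // First shape of the kAllDifferent case above: M = 128, N = 256, K = 384.
  const auto s = make_operand_shapes(128, 256, 384, /*transa=*/true, /*transb=*/false);
  assert(s.a[0] == 256 && s.a[1] == 384);  // A stored as [N, K]
  assert(s.b[0] == 128 && s.b[1] == 384);  // B stored as [M, K]
  assert(s.d[0] == 128 && s.d[1] == 256);  // D stored as [M, N]
  return 0;
}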
201 changes: 200 additions & 1 deletion tests/pytorch/test_numerics.py
@@ -46,7 +46,12 @@
    is_nvfp4_available,
)
from transformer_engine.pytorch import checkpoint as te_checkpoint
from transformer_engine.pytorch.cpp_extensions import general_gemm, general_grouped_gemm
from transformer_engine.pytorch.cpp_extensions import (
    general_gemm,
    general_grouped_gemm,
    general_grouped_gemm_for_grouped_tensor,
)
from transformer_engine.pytorch.tensor.storage.grouped_tensor import GroupedTensor
from transformer_engine.common import recipe
import transformer_engine_torch as tex
from utils import ModelConfig, reset_rng_states
@@ -1991,6 +1996,82 @@ def test_grouped_linear_accuracy(
torch.testing.assert_close(o, o_ref, rtol=0, atol=0)


@pytest.mark.parametrize("single_weight", [True, False], ids=["single_weight", "multi_weight"])
def test_grouped_linear_m_splits_tensor(single_weight):
"""Test GroupedLinear with m_splits as torch tensor (no_quantization/bf16).
grouped_tensor_path is chosen and must match reference (single_weight vs reference model,
or multi_weight list m_splits vs tensor m_splits).
"""
    if tex.get_cublasLt_version() < 130200:
        pytest.skip("Grouped GEMM requires cuBLAS 13.2+.")
    if torch.cuda.get_device_capability() < (10, 0):
        pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
    if not is_bf16_available():
        pytest.skip("bfloat16 is required for grouped GEMM test.")

    torch.manual_seed(0)
    num_gemms = 3
    in_features = 32
    out_features = 64
    m_splits = torch.tensor([5, 7, 9], device="cuda", dtype=torch.int64)
    m_splits_list = [5, 7, 9]
    dtype = torch.bfloat16
    m_total = int(m_splits.sum().item())

    reference_model = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=False,
        params_dtype=dtype,
        device="cuda",
        single_weight=False,
    )
    with torch.no_grad():
        ref_weights = [getattr(reference_model, f"weight{i}") for i in range(num_gemms)]

    test_model = GroupedLinear(
        num_gemms,
        in_features,
        out_features,
        bias=False,
        params_dtype=dtype,
        device="cuda",
        single_weight=single_weight,
    )
    with torch.no_grad():
        if single_weight:
            for i, w in enumerate(test_model.grouped_weight_storage.split_into_quantized_tensors()):
                w.copy_(ref_weights[i])
        else:
            for i in range(num_gemms):
                getattr(test_model, f"weight{i}").copy_(ref_weights[i])

    inp = torch.randn(m_total, in_features, device="cuda", dtype=dtype, requires_grad=True)
    inp_ref = inp.detach().clone().requires_grad_()

    if single_weight:
        out = test_model(inp, m_splits)
        out_ref = reference_model(inp_ref, m_splits)
    else:
        out = test_model(inp, m_splits)
        out_ref = reference_model(inp_ref, m_splits_list)

    torch.testing.assert_close(out, out_ref, **dtype_tols(dtype))

    out.sum().backward()
    out_ref.sum().backward()

    torch.testing.assert_close(inp.grad, inp_ref.grad, **dtype_tols(dtype))
    if single_weight:
        ref_wgrad = torch.cat(
            [getattr(reference_model, f"weight{i}").grad.view(-1) for i in range(num_gemms)]
        )
        torch.testing.assert_close(
            getattr(test_model, "weight0").grad, ref_wgrad, **dtype_tols(dtype)
        )


@pytest.mark.skipif(
    torch.cuda.get_device_capability() != (9, 0),
    reason="Only enable CUTLASS grouped gemm on Hopper",
@@ -2790,6 +2871,124 @@ def test_grouped_gemm(shape, dtype, layout, accumulate, use_cutlass):
os.environ.pop("NVTE_USE_CUTLASS_GROUPED_GEMM", None)


def _pack_grouped_tensor(grouped_tensor: GroupedTensor, tensors: List[torch.Tensor]) -> None:
    """Copy the tensors, flattened and back to back, into the GroupedTensor's data buffer."""
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        grouped_tensor.data[offset : offset + numel].copy_(tensor.reshape(-1))
        offset += numel


@pytest.mark.parametrize("layout", ["TN", "NN", "NT"])
@pytest.mark.parametrize("accumulate", [False])
def test_grouped_gemm_grouped_tensor(layout, accumulate):
    if tex.get_cublasLt_version() < 130200:
        pytest.skip("Grouped GEMM requires cuBLAS 13.2+.")
    if torch.cuda.get_device_capability() < (10, 0):
        pytest.skip("Grouped GEMM requires Blackwell (SM100) or newer.")
    if not is_bf16_available():
        pytest.skip("bfloat16 is required for grouped GEMM test.")

    torch.manual_seed(0)
    z, m, k, n = (4, 512, 256, 256)

    split_points = torch.randperm(m - 1)[: z - 1] + 1
    split_points = torch.sort(split_points).values.tolist()
    m_sizes = [split_points[0]]
    m_sizes += [b - a for a, b in zip(split_points[:-1], split_points[1:])]
    m_sizes.append(m - split_points[-1])
    assert sum(m_sizes) == m and len(m_sizes) == z

    dtype = torch.bfloat16

    if layout == "TN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes]  # input
        out = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes]  # output
        grad = False

    elif layout == "NN":
        A = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # weight
        B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes]  # grad_output
        out = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes]  # dgrad
        grad = True
    else:  # layout == "NT"
        A = [torch.randn(ms, k, dtype=dtype, device="cuda") for ms in m_sizes]  # input
        B = [torch.randn(ms, n, dtype=dtype, device="cuda") for ms in m_sizes]  # grad_output
        out = [torch.randn(n, k, dtype=dtype, device="cuda") for _ in range(z)]  # wgrad
        grad = True

    out_ref = [o.clone() for o in out]
    general_grouped_gemm(
        A,
        B,
        out_ref,
        [None] * z,
        dtype,
        m_splits=m_sizes,
        grad=grad,
        accumulate=accumulate,
        layout=layout,
        single_output=False,
    )

    device = A[0].device

    def _make_grouped_tensor_from_splits(m_sizes, last_dim):
        first_dims = torch.tensor(m_sizes, device=device, dtype=torch.int64)
        return GroupedTensor.make_grouped_tensor(
            num_tensors=len(m_sizes),
            first_dims=first_dims,
            last_dims=None,
            logical_first_dim=sum(m_sizes),
            logical_last_dim=last_dim,
            quantizer=None,
            device=device,
            dtype=dtype,
        )

    def _make_grouped_tensor_uniform(num_tensors, first_dim, last_dim):
        return GroupedTensor.make_grouped_tensor(
            num_tensors=num_tensors,
            first_dims=None,
            last_dims=None,
            logical_first_dim=num_tensors * first_dim,
            logical_last_dim=last_dim,
            quantizer=None,
            device=device,
            dtype=dtype,
        )

    if layout == "TN":
        grouped_A = _make_grouped_tensor_uniform(z, n, k)
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, k)
        grouped_out = _make_grouped_tensor_from_splits(m_sizes, n)
    elif layout == "NN":
        grouped_A = _make_grouped_tensor_uniform(z, n, k)
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n)
        grouped_out = _make_grouped_tensor_from_splits(m_sizes, k)
    else:  # layout == "NT"
        grouped_A = _make_grouped_tensor_from_splits(m_sizes, k)
        grouped_B = _make_grouped_tensor_from_splits(m_sizes, n)
        grouped_out = _make_grouped_tensor_uniform(z, n, k)
    _pack_grouped_tensor(grouped_A, A)
    _pack_grouped_tensor(grouped_B, B)
    _pack_grouped_tensor(grouped_out, out)

    general_grouped_gemm_for_grouped_tensor(
        grouped_A,
        grouped_B,
        grouped_out,
        layout=layout,
        accumulate=accumulate,
    )

    out_grouped = grouped_out.split_into_quantized_tensors()
    tols = dtype_tols(dtype)
    for o, o_ref in zip(out_grouped, out_ref):
        torch.testing.assert_close(o, o_ref, **tols)


@pytest.mark.parametrize("N", [32])
@pytest.mark.parametrize("datatype", [torch.float16, torch.bfloat16])
@pytest.mark.parametrize(
8 changes: 4 additions & 4 deletions transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -11,6 +11,7 @@
#include <transformer_engine/transformer_engine.h>

#include <cstdint>
#include <vector>

#include "../common.h"
#include "../util/cuda_runtime.h"
@@ -138,7 +139,6 @@ struct GroupedGemmSetupWorkspace {
offset += ptr_size;
ws.beta_ptrs = reinterpret_cast<float **>(setup_ws_ptr + offset);
offset += ptr_size;

// Int arrays for storage dimensions (4-byte aligned)
ws.a_rows = reinterpret_cast<int *>(setup_ws_ptr + offset);
offset += int_size;
@@ -487,9 +487,9 @@ __global__ void setup_grouped_gemm_kernel(
a_cols[idx] = static_cast<int>(a_first);
b_rows[idx] = static_cast<int>(b_last);
b_cols[idx] = static_cast<int>(b_first);
// For OUTPUTS (D, C): cuBLAS writes in column-major, so rows=first (M), cols=last (N).
d_rows[idx] = static_cast<int>(d_first);
d_cols[idx] = static_cast<int>(d_last);

d_rows[idx] = static_cast<int>(d_last);
d_cols[idx] = static_cast<int>(d_first);

// Fill alpha/beta pointers (per-matrix)
alpha_ptrs[idx] = alpha_ptr + idx;
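The swapped d_rows/d_cols assignment is consistent with treating the row-major output buffer as its column-major transpose, which is how cuBLAS indexes matrices (as the removed comment notes): a row-major [first, last] matrix, reinterpreted column-major without moving data, has rows = last and cols = first. A small standalone sketch of that index identity (illustrative only, not part of the PR):

#include <cassert>
#include <cstddef>

// Flat offset of element (i, j) for each storage order.
std::size_t row_major_offset(std::size_t i, std::size_t j, std::size_t cols) {
  return i * cols + j;
}
std::size_t col_major_offset(std::size_t i, std::size_t j, std::size_t rows) {
  return j * rows + i;
}

int main() {
  const std::size_t first = 128, last = 256;  // row-major D stored as [first, last]
  for (std::size_t i = 0; i < first; ++i) {
    for (std::size_t j = 0; j < last; ++j) {
      // Element (i, j) of the row-major [first, last] view is element (j, i) of a
      // column-major view with rows = last and cols = first; the flat offsets match.
      assert(row_major_offset(i, j, /*cols=*/last) ==
             col_major_offset(j, i, /*rows=*/last));
    }
  }
  return 0;
}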