diff --git a/tests/pytorch/test_grouped_tensor.py b/tests/pytorch/test_grouped_tensor.py index 318009c669..ad08c0474d 100644 --- a/tests/pytorch/test_grouped_tensor.py +++ b/tests/pytorch/test_grouped_tensor.py @@ -55,7 +55,7 @@ def make_quantizer(quantization: str, num_tensors: int, shape: List[Tuple[int, int]]) -> Quantizer: - """Create quantizers for given quantization scheme""" + """Create quantizer for given quantization scheme""" if quantization == "fp8_delayed_scaling": quantizer = Float8Quantizer( @@ -203,12 +203,12 @@ def test_split_into_quantized_tensors_quantized(self, quantization: str) -> None """Test split_into_quantized_tensors for quantized tensors""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -260,12 +260,12 @@ def test_quantize_inplace(self, quantization: str) -> None: """Test that quantize is done in-place for all recipes""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -300,12 +300,12 @@ def test_quantize_varying_shapes(self, quantization: str) -> None: """Test quantize with varying shapes""" num_tensors = 3 shape = [(256, 512), (512, 512), (768, 512)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) grouped_tensor = GroupedTensor.make_grouped_tensor_with_shapes( num_tensors=num_tensors, shape=shape, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -334,7 +334,7 @@ def test_static_quantize_method(self, quantization: str) -> None: """Test the static quantize method""" num_tensors = 3 shape = [(512, 512) for _ in range(num_tensors)] - quantizers = make_quantizer(quantization, num_tensors, shape) + quantizer = make_quantizer(quantization, num_tensors, shape) # Create input tensors input_tensors = [torch.randn(s, dtype=torch.float32, device="cuda") for s in shape] @@ -342,7 +342,7 @@ def test_static_quantize_method(self, quantization: str) -> None: # Use static quantize method grouped_tensor = GroupedTensor.create_and_quantize( tensors=input_tensors, - quantizer=quantizers, + quantizer=quantizer, device="cuda", ) @@ -361,6 +361,99 @@ def test_static_quantize_method(self, quantization: str) -> None: expected_offset = _rowwise_offset_bytes(i * numel, quantization) assert rowwise_data.data_ptr() == original_data_ptr + expected_offset + @pytest.mark.parametrize( + "shape", + [[(256, 512), (512, 512), (768, 512)], [(512, 512), (512, 512), (512, 512)]], + ) + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_quantize_grouped_mxfp8(self, shape: List[Tuple[int, int]]) -> None: + """Test grouped quantization for MXFP8 against per-tensor quantization.""" + # Test wont pass until the grouped quantization PR from Oleg is merged. 
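+ # Grouped layout used below: member tensors are packed by concatenation along
+ # dim 0 into one 2D buffer, and `first_dims` records each member's first
+ # dimension (all members share the last dimension).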
+ num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + + # Create BF16 input tensors and pack into a 2D tensor + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + quantized_tensors = [ + MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)(tensor) for tensor in input_tensors + ] + grouped_input = torch.cat(input_tensors, dim=0) + + # Create MXFP8 output grouped tensor (rowwise only for easier validation) + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + # Quantize using grouped API + grouped_output = tex.group_quantize( + grouped_input, + quantizer, + num_tensors, + first_dims, + ) + # Build expected output by quantizing each tensor independently + expected_data = [] + expected_scale_inv = [] + for tensor in input_tensors: + qtensor = quantizer(tensor) + expected_data.append(qtensor._rowwise_data.reshape(-1)) + expected_scale_inv.append(qtensor._rowwise_scale_inv.reshape(-1)) + + expected_data = torch.cat(expected_data) + expected_scale_inv = torch.cat(expected_scale_inv) + + assert torch.equal(grouped_output.data, expected_data) + assert torch.equal(grouped_output.scale_inv, expected_scale_inv) + + @pytest.mark.skipif(not mxfp8_available, reason=reason_for_no_mxfp8) + def test_group_quantize_cudagraph_capturable(self) -> None: + """Ensure group_quantize is CUDA graph capturable.""" + num_tensors = 2 + shape = [(512, 1024) for _ in range(num_tensors)] + input_tensors = [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape] + grouped_input = torch.cat(input_tensors, dim=0) + + quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3) + quantizer.set_usage(rowwise=True, columnwise=False) + first_dims = torch.tensor( + [shape[0][0] for _ in range(num_tensors)], + dtype=torch.int64, + device="cuda", + ) + + torch.cuda.synchronize() + static_input = grouped_input.clone() + static_first_dims = first_dims.clone() + + # Warmup to initialize kernels and allocator state + _ = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + torch.cuda.synchronize() + + graph = torch.cuda.CUDAGraph() + with torch.cuda.graph(graph): + static_output = tex.group_quantize( + static_input, + quantizer, + num_tensors, + static_first_dims, + ) + + fresh_input = torch.cat( + [torch.randn(s, dtype=torch.bfloat16, device="cuda") for s in shape], + dim=0, + ) + static_input.copy_(fresh_input) + graph.replay() + torch.cuda.synchronize() + + expected = tex.group_quantize(static_input, quantizer, num_tensors, static_first_dims) + assert torch.equal(static_output.data, expected.data) + assert torch.equal(static_output.scale_inv, expected.scale_inv) + def test_clear(self) -> None: """Test clear method""" num_tensors = 3 diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index a29a09836e..aae5fc4e85 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -827,9 +827,9 @@ void group_quantize(const GroupedTensor *input, const GroupedTensor *activations "First dimension of a grouped tensor should be divisible by 128."); } - const int64_t *const offsets_ptr = reinterpret_cast(input->tensor_offsets.dptr); - const int64_t *const first_dims_ptr = reinterpret_cast(input->first_dims.dptr); - 
const int64_t *const last_dims_ptr = reinterpret_cast(input->last_dims.dptr); + const int64_t *const offsets_ptr = reinterpret_cast(output->tensor_offsets.dptr); + const int64_t *const first_dims_ptr = reinterpret_cast(output->first_dims.dptr); + const int64_t *const last_dims_ptr = reinterpret_cast(output->last_dims.dptr); float *const workspace_ptr = IS_DBIAS ? reinterpret_cast(workspace->data.dptr) : nullptr; float *const amax_ptr = reinterpret_cast(output->amax.dptr); diff --git a/transformer_engine/common/include/transformer_engine/transformer_engine.h b/transformer_engine/common/include/transformer_engine/transformer_engine.h index ae41f238a4..ab35a1c68c 100644 --- a/transformer_engine/common/include/transformer_engine/transformer_engine.h +++ b/transformer_engine/common/include/transformer_engine/transformer_engine.h @@ -957,8 +957,221 @@ class TensorWrapper { NVTETensor tensor_ = nullptr; }; -/*! \warning Deprecated */ -enum class Float8BlockScaleTensorFormat { GEMM_READY = 0, COMPACT = 1, INVALID }; +/*! \struct GroupedTensorWrapper + * \brief C++ wrapper for the NVTEGroupedTensor class. + */ + +class GroupedTensorWrapper { + public: + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * TE grouped tensors are just wrappers on top of raw data and do not + * own memory. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const NVTEShape &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : tensor_(nvte_create_grouped_tensor(scaling_mode, num_tensors, logical_shape)) {} + + /*! \brief Constructs new GroupedTensorWrapper. + * + * Create a new TE grouped tensor with a given logical shape. + * + * \param[in] num_tensors Number of tensors in the group (must be > 0). + * \param[in] logical_shape Logical 2D shape of the grouped data. + * \param[in] scaling_mode Tensor data format. + */ + GroupedTensorWrapper(const size_t num_tensors, const std::vector &logical_shape, + const NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING) + : GroupedTensorWrapper(num_tensors, + nvte_make_shape(logical_shape.data(), logical_shape.size()), + scaling_mode) {} + + /*! \brief GroupedTensorWrapper destructor. */ + ~GroupedTensorWrapper() { nvte_destroy_grouped_tensor(tensor_); } + + GroupedTensorWrapper &operator=(const GroupedTensorWrapper &other) = delete; + GroupedTensorWrapper(const GroupedTensorWrapper &other) = delete; + + /*! \brief Constructs new GroupedTensorWrapper from existing GroupedTensorWrapper. */ + GroupedTensorWrapper(GroupedTensorWrapper &&other) { + tensor_ = other.tensor_; + other.tensor_ = nullptr; + } + + /*! \brief Assign the data from existing GroupedTensorWrapper. 
*/ + GroupedTensorWrapper &operator=(GroupedTensorWrapper &&other) { + if (this == &other) return *this; + nvte_destroy_grouped_tensor(tensor_); + tensor_ = other.tensor_; + other.tensor_ = nullptr; + return *this; + } + + // Parameter setters + template + GroupedTensorWrapper &set_parameter(const NVTEGroupedTensorParam param, void *dptr, DType type, + const ShapeType &shape) noexcept { + NVTEShape nvte_shape = this->convertShape(shape); + NVTEBasicTensor data = {dptr, static_cast(type), nvte_shape}; + nvte_set_grouped_tensor_param(&tensor_, param, &data); + return *this; + } + + template + GroupedTensorWrapper &set_rowwise_data(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_data(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseData, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_scale(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedScale, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_amax(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_rowwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedRowwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_scale_inv(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseScaleInv, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_columnwise_amax(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedColumnwiseAmax, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_first_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedFirstDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_last_dims(void *dptr, DType type, const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedLastDims, dptr, type, shape); + } + + template + GroupedTensorWrapper &set_tensor_offsets(void *dptr, DType type, + const ShapeType &shape) noexcept { + return set_parameter(kNVTEGroupedTensorOffsets, dptr, type, shape); + } + + // Parameter getters + NVTEBasicTensor get_parameter(const NVTEGroupedTensorParam param) const noexcept { + return nvte_get_grouped_tensor_param(tensor_, param); + } + + NVTEBasicTensor get_rowwise_data() const noexcept { + return get_parameter(kNVTEGroupedRowwiseData); + } + + NVTEBasicTensor get_columnwise_data() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseData); + } + + NVTEBasicTensor get_scale() const noexcept { return get_parameter(kNVTEGroupedScale); } + + NVTEBasicTensor get_amax() const noexcept { return get_parameter(kNVTEGroupedAmax); } + + NVTEBasicTensor get_rowwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedRowwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_scale_inv() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseScaleInv); + } + + NVTEBasicTensor get_columnwise_amax() const noexcept { + return get_parameter(kNVTEGroupedColumnwiseAmax); + } + + NVTEBasicTensor get_first_dims() const noexcept { return get_parameter(kNVTEGroupedFirstDims); } + + NVTEBasicTensor get_last_dims() const noexcept { return 
get_parameter(kNVTEGroupedLastDims); } + + NVTEBasicTensor get_tensor_offsets() const noexcept { + return get_parameter(kNVTEGroupedTensorOffsets); + } + + /*! \brief Get an underlying NVTEGroupedTensor. + * + * \return NVTEGroupedTensor held by this GroupedTensorWrapper. + */ + NVTEGroupedTensor data() const noexcept { return tensor_; } + + /*! \brief Get the number of tensors in this GroupedTensorWrapper. */ + size_t num_tensors() const noexcept { + if (tensor_ == nullptr) return 0; + return nvte_grouped_tensor_num_tensors(tensor_); + } + + /*! \brief Get the data type of this GroupedTensorWrapper. */ + DType dtype() const noexcept { + if (tensor_ == nullptr) return DType::kNumTypes; + return static_cast(nvte_grouped_tensor_type(tensor_)); + } + + /*! \brief Get a scaling mode of the grouped tensor. */ + NVTEScalingMode scaling_mode() const noexcept { + if (tensor_ == nullptr) return NVTE_DELAYED_TENSOR_SCALING; + return nvte_grouped_tensor_scaling_mode(tensor_); + } + + /*! \brief Get the logical shape of this GroupedTensorWrapper. */ + const NVTEShape logical_shape() const noexcept { + if (tensor_ == nullptr) { + return emptyShape; + } + return nvte_get_grouped_tensor_logical_shape(tensor_); + } + + static constexpr size_t defaultData = 1; + static constexpr NVTEShape defaultShape = { + {defaultData, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + static constexpr NVTEShape emptyShape = {{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, 1}; + + private: + NVTEShape convertShape(const NVTEShape &s) { return s; } + + NVTEShape convertShape(const std::vector &s) { + return nvte_make_shape(s.data(), s.size()); + } + + /*! \brief Wrapped NVTEGroupedTensor. */ + NVTEGroupedTensor tensor_ = nullptr; +}; + +/*! \enum Float8BlockScaleTensorFormat + * \brief Data format for an FP8 block-scaled tensor + */ +enum class Float8BlockScaleTensorFormat { + /*! FP8 data is transposed if needed and scales are swizzled */ + GEMM_READY = 0, + /*! FP8 data is untransposed and scales are not swizzled or padded */ + COMPACT = 1, + INVALID +}; /*! \struct QuantizationConfigWrapper * \brief C++ wrapper for NVTEQuantizationConfigWrapper. diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h index bc22e03097..6aab9938b3 100644 --- a/transformer_engine/pytorch/csrc/common.h +++ b/transformer_engine/pytorch/csrc/common.h @@ -103,6 +103,12 @@ class Quantizer { virtual std::pair create_tensor(const std::vector& shape, DType dtype) const = 0; + /*! @brief Construct a grouped tensor with uninitialized data */ + virtual std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const = 0; + /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor * * The PyTorch tensor's attributes are modified to match the @@ -138,6 +144,11 @@ class NoneQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! 
@brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, at::Tensor data) const; @@ -164,6 +175,11 @@ class Float8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct a tensor with pre-initialized data */ std::pair create_tensor(const std::vector& shape, DType dtype, std::optional data, @@ -196,6 +212,11 @@ class Float8CurrentScalingQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer. * * The amax is zeroed out. Most TE kernels that output amax expect @@ -253,6 +274,11 @@ class Float8BlockQuantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -274,6 +300,11 @@ class MXFP8Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + std::pair convert_and_update_tensor(py::object shape) const override; void quantize(const TensorWrapper& input, TensorWrapper& out, @@ -308,6 +339,11 @@ class NVFP4Quantizer : public Quantizer { std::pair create_tensor(const std::vector& shape, DType dtype) const override; + std::pair create_grouped_tensor( + size_t num_tensors, const std::vector& logical_shape, DType dtype, + py::object quantizer, const std::optional& first_dims, size_t logical_first_dim, + size_t logical_last_dim) const override; + /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer * * The amax is zeroed out. 
Most TE kernels that output amax expect diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h index e0ea3d6b78..8f6189fc8d 100644 --- a/transformer_engine/pytorch/csrc/extensions.h +++ b/transformer_engine/pytorch/csrc/extensions.h @@ -250,6 +250,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob py::object dequantize(const py::handle &input, DType otype); +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims); + std::vector multi_tensor_quantize(const std::vector &tensor_list, std::vector quantizer_list); diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp index 5c9d0f5b07..7976454f36 100644 --- a/transformer_engine/pytorch/csrc/extensions/cast.cpp +++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp @@ -80,6 +80,42 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob return output_py; } +// NOTE: Only supports varying first dim. +py::object group_quantize(const at::Tensor &tensor, py::handle quantizer, const size_t num_tensors, + std::optional first_dims) { + using namespace transformer_engine::pytorch::detail; + init_extension(); + + NVTE_CHECK(tensor.dim() == 2, "Tensor must be 2D"); + + std::vector logical_shape; + for (const auto &d : tensor.sizes()) { + logical_shape.push_back(d); + } + const auto logical_first_dim = logical_shape[0]; + const auto logical_last_dim = logical_shape[1]; + + auto quantizer_cpp = convert_quantizer(quantizer); + + // Create input GroupedTensor. + auto grouped_input_tensor = GroupedTensorWrapper(num_tensors, logical_shape); + grouped_input_tensor.set_rowwise_data( + tensor.data_ptr(), GetTransformerEngineDType(tensor.scalar_type()), getTensorShape(tensor)); + + // Create output GroupedTensor. 
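+ // The quantizer's create_grouped_tensor allocates the packed rowwise/columnwise
+ // data and scale-inv buffers on the CUDA device and, when first_dims is provided,
+ // builds per-tensor element offsets as a zero-prefixed cumulative sum of
+ // first_dims * last_dim (e.g. first_dims = {256, 512} with last dim 1024 gives
+ // offsets {0, 262144, 786432}). It returns both the C++ wrapper and the Python
+ // GroupedTensor that owns the storage.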
+ auto [grouped_output_tensor_cpp, grouped_output_py] = quantizer_cpp->create_grouped_tensor( + num_tensors, logical_shape, GetTransformerEngineDType(tensor.scalar_type()), + py::reinterpret_borrow(quantizer), first_dims, logical_first_dim, + logical_last_dim); + + NVTE_SCOPED_GIL_RELEASE({ + nvte_group_quantize(grouped_input_tensor.data(), grouped_output_tensor_cpp.data(), + at::cuda::getCurrentCUDAStream()); + }); + + return py::reinterpret_borrow(grouped_output_py); +} + py::object dequantize(const py::handle &input, transformer_engine::DType otype) { init_extension(); diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp index 1e907d9bc0..5e9eccced0 100644 --- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp +++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp @@ -35,6 +35,7 @@ PyTypeObject *Float8BlockwiseQuantizerClass = nullptr; PyTypeObject *NVFP4TensorPythonClass = nullptr; PyTypeObject *NVFP4TensorStoragePythonClass = nullptr; PyTypeObject *NVFP4QuantizerClass = nullptr; +PyTypeObject *GroupedTensorStoragePythonClass = nullptr; void init_float8_extension() { if (Float8TensorPythonClass) return; @@ -104,11 +105,22 @@ void init_nvfp4_extensions() { "Internal error: could not initialize pyTorch NVFP4 extension."); } +void init_grouped_tensor_extension() { + if (GroupedTensorStoragePythonClass) return; + auto grouped_tensor_module = + py::module_::import("transformer_engine.pytorch.tensor.storage.grouped_tensor"); + GroupedTensorStoragePythonClass = reinterpret_cast( + PyObject_GetAttrString(grouped_tensor_module.ptr(), "GroupedTensor")); + NVTE_CHECK(GroupedTensorStoragePythonClass != nullptr, + "Internal error: could not initialize pyTorch grouped tensor extension."); +} + void init_extension() { init_float8_extension(); init_mxfp8_extension(); init_float8blockwise_extension(); init_nvfp4_extensions(); + init_grouped_tensor_extension(); } } // namespace transformer_engine::pytorch @@ -121,7 +133,8 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) { py::arg("output") = py::none(), py::arg("noop") = py::none()); m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"), py::arg("otype")); - + m.def("group_quantize", transformer_engine::pytorch::group_quantize, py::arg("tensor"), + py::arg("quantizer"), py::arg("num_tensors"), py::arg("first_dims")); m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize, "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer")); m.def("generic_gemm", transformer_engine::pytorch::gemm, "Compute GEMM (matrix-matrix multiply)", diff --git a/transformer_engine/pytorch/csrc/pybind.h b/transformer_engine/pytorch/csrc/pybind.h index 25ffef0588..059eb5e3fb 100644 --- a/transformer_engine/pytorch/csrc/pybind.h +++ b/transformer_engine/pytorch/csrc/pybind.h @@ -43,6 +43,7 @@ extern PyTypeObject *Float8BlockwiseQuantizerClass; extern PyTypeObject *NVFP4TensorPythonClass; extern PyTypeObject *NVFP4TensorStoragePythonClass; extern PyTypeObject *NVFP4QuantizerClass; +extern PyTypeObject *GroupedTensorStoragePythonClass; void init_extension(); @@ -95,6 +96,8 @@ TensorWrapper NVTETensorFromFloat8BlockwiseQTensor(py::handle tensor, TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer); +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor); + inline bool IsFloatingPointType(at::ScalarType type) { return type == at::kFloat || type == at::kHalf || type == at::kBFloat16; } diff 
--git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp index 1c968e276d..b6c8ba5e74 100644 --- a/transformer_engine/pytorch/csrc/quantizer.cpp +++ b/transformer_engine/pytorch/csrc/quantizer.cpp @@ -42,6 +42,35 @@ std::vector convert_shape_for_fp4(const std::vector& shape) { return ret; } +std::optional build_grouped_tensor_offsets(const size_t num_tensors, + const std::optional& first_dims, + const size_t logical_last_dim) { + if (!first_dims.has_value()) { + return std::nullopt; + } + + const auto& first_dims_tensor = first_dims.value(); + NVTE_CHECK(first_dims_tensor.scalar_type() == at::kLong, "first_dims must have dtype int64."); + NVTE_CHECK(static_cast(first_dims_tensor.numel()) == num_tensors, + "first_dims must have length ", num_tensors, "."); + + const int64_t logical_last_dim_i64 = static_cast(logical_last_dim); + auto scaled_first_dims = first_dims_tensor * logical_last_dim_i64; + + // Single kernel needed for these ops. + auto cumsum = at::cumsum(scaled_first_dims, 0); + auto zero = at::zeros({1}, cumsum.options()); + return at::cat({zero, cumsum}); +} + +at::TensorOptions grouped_tensor_data_options(const DType dtype) { + return at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA); +} + +py::object maybe_tensor_to_py(const std::optional& tensor) { + return tensor ? py::cast(*tensor) : py::none(); +} + } // namespace constexpr size_t NVFP4_BLOCK_SIZE = 16; @@ -88,6 +117,60 @@ std::pair NoneQuantizer::create_tensor(const std::vec return {std::move(out_cpp), py::cast(data)}; } +std::pair NoneQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + std::optional rowwise_data; + std::optional columnwise_data; + const bool with_rowwise_data = rowwise_usage; + const bool with_columnwise_data = columnwise_usage; + if (with_rowwise_data) { + rowwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + if (with_columnwise_data) { + columnwise_data = at::empty({total_elements}, grouped_tensor_data_options(dtype)); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (with_rowwise_data) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), dtype, getTensorShape(*rowwise_data)); + } + if (with_columnwise_data) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), dtype, + getTensorShape(*columnwise_data)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), "scale_inv"_a = py::none(), + "columnwise_scale_inv"_a = py::none(), "amax"_a = py::none(), + 
"columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NoneQuantizer::convert_and_update_tensor( py::object tensor) const { auto tensor_pyt = tensor.cast(); @@ -184,6 +267,73 @@ std::pair Float8Quantizer::create_tensor( return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsFloat8Tensor(tensor.ptr()), "Float8Quantizer must output to Float8Tensor."); @@ -390,6 +540,75 @@ std::pair Float8CurrentScalingQuantizer::create_tenso return {std::move(out_cpp), std::move(out_py)}; } +std::pair Float8CurrentScalingQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + at::Tensor scale = at::empty({static_cast(num_tensors)}, float_opts); + at::Tensor amax = at::empty({static_cast(num_tensors)}, float_opts); + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + rowwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + columnwise_scale_inv = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + out_cpp.set_scale(scale.data_ptr(), DType::kFloat32, getTensorShape(scale)); + out_cpp.set_amax(amax.data_ptr(), DType::kFloat32, getTensorShape(amax)); + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = amax, + "columnwise_amax"_a = py::none(), "scale"_a = scale, + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8CurrentScalingQuantizer::create_unquantized_tensor_with_amax(const std::vector& shape, DType dtype, @@ -638,6 +857,77 @@ std::pair Float8BlockQuantizer::create_tensor( return {std::move(tensor), std::move(ret)}; } +std::pair Float8BlockQuantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair Float8BlockQuantizer::convert_and_update_tensor( py::object tensor) const { const DType dtype = tensor.attr("_fp8_dtype").cast(); @@ -940,6 +1230,76 @@ std::pair MXFP8Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair MXFP8Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + if (rowwise_usage) { + rowwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*rowwise_scale_inv)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E8M0, + getTensorShape(*columnwise_scale_inv)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), "amax"_a = py::none(), + "columnwise_amax"_a = py::none(), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? 
py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair MXFP8Quantizer::convert_and_update_tensor( py::object tensor) const { NVTE_CHECK(detail::IsMXFP8Tensor(tensor.ptr()), "MXFP8Quantizer must output to MXFP8Tensor."); @@ -1240,6 +1600,88 @@ std::pair NVFP4Quantizer::create_tensor(const std::ve return {std::move(out_cpp), std::move(out_py)}; } +std::pair NVFP4Quantizer::create_grouped_tensor( + const size_t num_tensors, const std::vector& logical_shape, const DType dtype, + py::object quantizer, const std::optional& first_dims, + const size_t logical_first_dim, const size_t logical_last_dim) const { + using namespace pybind11::literals; + + const auto tensor_offsets = + build_grouped_tensor_offsets(num_tensors, first_dims, logical_last_dim); + const int64_t total_elements = + static_cast(logical_first_dim) * static_cast(logical_last_dim); + NVTE_CHECK(total_elements % 2 == 0, "NVFP4 data size must be divisible by 2."); + + const auto uint8_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA); + const auto float_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA); + + std::optional rowwise_data; + std::optional columnwise_data; + std::optional rowwise_scale_inv; + std::optional columnwise_scale_inv; + std::optional rowwise_amax; + std::optional columnwise_amax; + const std::vector logical_shape_vec = {logical_first_dim, logical_last_dim}; + + const int64_t total_data_elements = total_elements / 2; + + if (rowwise_usage) { + rowwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, false); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + rowwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + rowwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + if (columnwise_usage) { + columnwise_data = at::empty({total_data_elements}, uint8_opts); + const auto scale_shape = get_scale_shape(logical_shape_vec, true); + const int64_t total_scale_elements = static_cast(product(scale_shape)); + columnwise_scale_inv = at::empty({total_scale_elements}, uint8_opts); + columnwise_amax = at::empty({static_cast(num_tensors)}, float_opts); + } + + GroupedTensorWrapper out_cpp(num_tensors, logical_shape, this->get_scaling_mode()); + if (rowwise_usage) { + out_cpp.set_rowwise_data(rowwise_data->data_ptr(), this->dtype, getTensorShape(*rowwise_data)); + out_cpp.set_rowwise_scale_inv(rowwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*rowwise_scale_inv)); + out_cpp.set_amax(rowwise_amax->data_ptr(), DType::kFloat32, getTensorShape(*rowwise_amax)); + } + if (columnwise_usage) { + out_cpp.set_columnwise_data(columnwise_data->data_ptr(), this->dtype, + getTensorShape(*columnwise_data)); + out_cpp.set_columnwise_scale_inv(columnwise_scale_inv->data_ptr(), DType::kFloat8E4M3, + getTensorShape(*columnwise_scale_inv)); + out_cpp.set_columnwise_amax(columnwise_amax->data_ptr(), DType::kFloat32, + getTensorShape(*columnwise_amax)); + } + if (first_dims.has_value()) { + out_cpp.set_first_dims(first_dims->data_ptr(), DType::kInt64, getTensorShape(*first_dims)); + } + if (tensor_offsets.has_value()) { + out_cpp.set_tensor_offsets(tensor_offsets->data_ptr(), DType::kInt64, + getTensorShape(*tensor_offsets)); + } + + py::handle GroupedTensorClass(reinterpret_cast(GroupedTensorStoragePythonClass)); + py::object 
out_py = GroupedTensorClass( + "num_tensors"_a = num_tensors, "quantizer"_a = std::move(quantizer), + "dtype"_a = GetATenDType(dtype), "data"_a = maybe_tensor_to_py(rowwise_data), + "columnwise_data"_a = maybe_tensor_to_py(columnwise_data), + "scale_inv"_a = maybe_tensor_to_py(rowwise_scale_inv), + "columnwise_scale_inv"_a = maybe_tensor_to_py(columnwise_scale_inv), + "amax"_a = maybe_tensor_to_py(rowwise_amax), + "columnwise_amax"_a = maybe_tensor_to_py(columnwise_amax), "scale"_a = py::none(), + "first_dims"_a = first_dims.has_value() ? py::cast(*first_dims) : py::none(), + "last_dims"_a = py::none(), + "tensor_offsets"_a = tensor_offsets.has_value() ? py::cast(*tensor_offsets) : py::none(), + "logical_shape"_a = std::vector{static_cast(logical_first_dim), + static_cast(logical_last_dim)}); + + return {std::move(out_cpp), std::move(out_py)}; +} + std::pair NVFP4Quantizer::create_unquantized_tensor_with_amax( TensorWrapper& quantized_tensor, DType dtype) { // Construct tensor diff --git a/transformer_engine/pytorch/csrc/type_converters.cpp b/transformer_engine/pytorch/csrc/type_converters.cpp index 3f998bb66f..eda5e8fc54 100644 --- a/transformer_engine/pytorch/csrc/type_converters.cpp +++ b/transformer_engine/pytorch/csrc/type_converters.cpp @@ -170,6 +170,121 @@ TensorWrapper NVTETensorFromNVFP4Tensor(py::handle tensor, Quantizer *quantizer) return ret; } +NVTEScalingMode ScalingModeFromQuantizer(py::handle quantizer) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return NVTE_MXFP8_1D_SCALING; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return NVTE_NVFP4_1D_SCALING; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + const int block_scaling_dim = quantizer.attr("block_scaling_dim").cast(); + return (block_scaling_dim == 2) ? NVTE_BLOCK_SCALING_2D : NVTE_BLOCK_SCALING_1D; + } + return NVTE_DELAYED_TENSOR_SCALING; +} + +DType GetTransformerEngineDTypeForScaleInv(py::handle quantizer, at::Tensor scale_inv) { + auto *quantizer_ptr = quantizer.ptr(); + if (IsMXFP8Quantizers(quantizer_ptr)) { + return DType::kFloat8E8M0; + } + if (IsFloat8BlockwiseQuantizers(quantizer_ptr)) { + return DType::kFloat32; + } + if (IsNVFP4Quantizers(quantizer_ptr)) { + return DType::kFloat8E4M3; + } + return GetTransformerEngineDType(scale_inv.scalar_type()); +} + +GroupedTensorWrapper GroupedTensorFromPyTorchGroupedTensor(py::handle tensor) { + // Returns a GroupedTensorWrapper from a PyTorch GroupedTensor. + const auto num_tensors = tensor.attr("num_tensors").cast(); + const auto logical_shape = tensor.attr("logical_shape").cast>(); + py::handle quantizer = py::none(); + DType quantizer_dtype = DType::kNumTypes; + NVTEScalingMode scaling_mode = NVTE_DELAYED_TENSOR_SCALING; + if (!tensor.attr("quantizer").is_none()) { + quantizer = tensor.attr("quantizer"); + if (!quantizer.is_none()) { + scaling_mode = ScalingModeFromQuantizer(quantizer); + quantizer_dtype = quantizer.attr("dtype").cast(); + } + } + auto ret = GroupedTensorWrapper(num_tensors, logical_shape, scaling_mode); + + // Rowwise data + if (!tensor.attr("data").is_none()) { + const auto &data = tensor.attr("data").cast(); + DType data_dtype = + quantizer.is_none() ? GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_rowwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Columnwise data + if (!tensor.attr("columnwise_data").is_none()) { + const auto &data = tensor.attr("columnwise_data").cast(); + DType data_dtype = + quantizer.is_none() ? 
GetTransformerEngineDType(data.scalar_type()) : quantizer_dtype; + ret.set_columnwise_data(data.data_ptr(), data_dtype, getTensorShape(data)); + } + + // Scale + if (!tensor.attr("scale").is_none()) { + const auto &scale = tensor.attr("scale").cast(); + ret.set_scale(scale.data_ptr(), GetTransformerEngineDType(scale.scalar_type()), + getTensorShape(scale)); + } + + // Amax + if (!tensor.attr("amax").is_none()) { + const auto &amax = tensor.attr("amax").cast(); + ret.set_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + if (!tensor.attr("columnwise_amax").is_none()) { + const auto &amax = tensor.attr("columnwise_amax").cast(); + ret.set_columnwise_amax(amax.data_ptr(), GetTransformerEngineDType(amax.scalar_type()), + getTensorShape(amax)); + } + + // Scale inverse + if (!tensor.attr("scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("scale_inv").cast(); + ret.set_rowwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + if (!tensor.attr("columnwise_scale_inv").is_none()) { + const auto &scale_inv = tensor.attr("columnwise_scale_inv").cast(); + ret.set_columnwise_scale_inv(scale_inv.data_ptr(), + GetTransformerEngineDTypeForScaleInv(quantizer, scale_inv), + getTensorShape(scale_inv)); + } + + // Shape metadata + if (!tensor.attr("first_dims").is_none()) { + const auto &first_dims = tensor.attr("first_dims").cast(); + ret.set_first_dims(first_dims.data_ptr(), GetTransformerEngineDType(first_dims.scalar_type()), + getTensorShape(first_dims)); + } + if (!tensor.attr("last_dims").is_none()) { + const auto &last_dims = tensor.attr("last_dims").cast(); + ret.set_last_dims(last_dims.data_ptr(), GetTransformerEngineDType(last_dims.scalar_type()), + getTensorShape(last_dims)); + } + if (!tensor.attr("tensor_offsets").is_none()) { + const auto &tensor_offsets = tensor.attr("tensor_offsets").cast(); + ret.set_tensor_offsets(tensor_offsets.data_ptr(), + GetTransformerEngineDType(tensor_offsets.scalar_type()), + getTensorShape(tensor_offsets)); + } + + return ret; +} + } // namespace detail } // namespace transformer_engine::pytorch diff --git a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py index dad4d1d0ea..e60007f05a 100644 --- a/transformer_engine/pytorch/tensor/storage/grouped_tensor.py +++ b/transformer_engine/pytorch/tensor/storage/grouped_tensor.py @@ -52,7 +52,7 @@ class GroupedTensor: def __init__( self, num_tensors: int, - shape: List[Tuple[int, int]], + shape: Optional[List[Tuple[int, int]]] = None, quantizer: Optional[Quantizer] = None, dtype: Optional[torch.dtype] = None, data: Optional[torch.Tensor] = None, @@ -245,6 +245,7 @@ def clear(self) -> None: """ Reset tensor data and clear all buffers. """ + self.shape = None self.data = None self.columnwise_data = None self.scale_inv = None @@ -623,6 +624,8 @@ def split_into_quantized_tensors( no_quantization = self.quantizer is None + assert self.shape is not None, "Shape must be set for splitting a GroupedTensor." + # Case 1: No quantization - return regular torch tensors if no_quantization: for i in range(self.num_tensors):
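A minimal usage sketch of the new tex.group_quantize entry point, mirroring the MXFP8 test added above; the import paths are assumptions, as the test file's import block is not shown in this diff:

    import torch
    import transformer_engine_torch as tex
    from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer  # assumed import path

    # Pack two BF16 tensors along dim 0 and record each member's first dimension
    tensors = [torch.randn(512, 1024, dtype=torch.bfloat16, device="cuda") for _ in range(2)]
    grouped_input = torch.cat(tensors, dim=0)
    first_dims = torch.tensor([t.shape[0] for t in tensors], dtype=torch.int64, device="cuda")

    quantizer = MXFP8Quantizer(fp8_dtype=tex.DType.kFloat8E4M3)
    quantizer.set_usage(rowwise=True, columnwise=False)

    # Returns a GroupedTensor whose packed .data and .scale_inv match per-tensor quantization
    grouped_out = tex.group_quantize(grouped_input, quantizer, len(tensors), first_dims)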