diff --git a/transformer_engine/pytorch/csrc/common.h b/transformer_engine/pytorch/csrc/common.h
index bc22e03097..fe6cac4650 100644
--- a/transformer_engine/pytorch/csrc/common.h
+++ b/transformer_engine/pytorch/csrc/common.h
@@ -101,7 +101,9 @@ class Quantizer {
 
   /*! @brief Construct a tensor with uninitialized data */
   virtual std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                             DType dtype) const = 0;
+                                                             DType dtype,
+                                                             at::Device device = torch::kCUDA,
+                                                             bool pin_memory = false) const = 0;
 
   /*! @brief Convert a PyTorch tensor into a Transformer Engine C++ tensor
    *
@@ -135,8 +137,9 @@ class NoneQuantizer : public Quantizer {
 
   void set_quantization_params(TensorWrapper* tensor) const override {}
 
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   /*! @brief Construct a tensor with pre-initialized data */
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
@@ -161,14 +164,17 @@ class Float8Quantizer : public Quantizer {
 
   void set_quantization_params(TensorWrapper* tensor) const override;
 
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   /*! @brief Construct a tensor with pre-initialized data */
   std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
                                                      std::optional<at::Tensor> data,
                                                      std::optional<at::Tensor> transpose,
-                                                     std::optional<at::Tensor> scale_inv) const;
+                                                     std::optional<at::Tensor> scale_inv,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const;
 
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
@@ -193,8 +199,9 @@ class Float8CurrentScalingQuantizer : public Quantizer {
 
   void set_quantization_params(TensorWrapper* tensor) const override;
 
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   /*! @brief Construct an unquantized tensor that shares the quantizer's amax pointer.
    *
@@ -250,8 +257,9 @@ class Float8BlockQuantizer : public Quantizer {
 
   // Create a python Float8BlockQuantized tensor and C++ wrapper
   // for the tensor. Should set quantized data, scales for rowwise
   // and optionally columnwise usage.
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
@@ -271,8 +279,9 @@ class MXFP8Quantizer : public Quantizer {
 
   void set_quantization_params(TensorWrapper* tensor) const override;
 
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   std::pair<TensorWrapper, py::object> convert_and_update_tensor(py::object shape) const override;
 
@@ -305,8 +314,9 @@ class NVFP4Quantizer : public Quantizer {
 
   void set_quantization_params(TensorWrapper* tensor) const override;
 
-  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape,
-                                                     DType dtype) const override;
+  std::pair<TensorWrapper, py::object> create_tensor(const std::vector<size_t>& shape, DType dtype,
+                                                     at::Device device = torch::kCUDA,
+                                                     bool pin_memory = false) const override;
 
   /*! @brief Construct an unquantized tensor that shares NVFP4 tensor's amax pointer
    *
diff --git a/transformer_engine/pytorch/csrc/extensions.h b/transformer_engine/pytorch/csrc/extensions.h
index f7cf32eaf6..036f4aef2c 100644
--- a/transformer_engine/pytorch/csrc/extensions.h
+++ b/transformer_engine/pytorch/csrc/extensions.h
@@ -250,6 +250,9 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
 
 py::object dequantize(const py::handle &input, DType otype);
 
+py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
+                                         at::ScalarType dtype, at::Device device, bool pin_memory);
+
 std::vector<py::object> multi_tensor_quantize(const std::vector<at::Tensor> &tensor_list,
                                               std::vector<py::handle> quantizer_list);
 
diff --git a/transformer_engine/pytorch/csrc/extensions/cast.cpp b/transformer_engine/pytorch/csrc/extensions/cast.cpp
index 5c9d0f5b07..4a566efe28 100644
--- a/transformer_engine/pytorch/csrc/extensions/cast.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/cast.cpp
@@ -80,6 +80,14 @@ py::object quantize(const at::Tensor &tensor, py::handle quantizer, const py::ob
   return output_py;
 }
 
+py::object create_empty_quantized_tensor(py::handle quantizer, const std::vector<size_t> &shape,
+                                         at::ScalarType dtype, at::Device device, bool pin_memory) {
+  auto quantizer_cpp = convert_quantizer(quantizer);
+  auto te_dtype = GetTransformerEngineDType(dtype);
+  auto [_, output_py] = quantizer_cpp->create_tensor(shape, te_dtype, device, pin_memory);
+  return output_py;
+}
+
 py::object dequantize(const py::handle &input, transformer_engine::DType otype) {
   init_extension();
 
diff --git a/transformer_engine/pytorch/csrc/extensions/pybind.cpp b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
index 79dd9ea5ce..9da123867d 100644
--- a/transformer_engine/pytorch/csrc/extensions/pybind.cpp
+++ b/transformer_engine/pytorch/csrc/extensions/pybind.cpp
@@ -121,6 +121,10 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
         py::arg("output") = py::none(), py::arg("noop") = py::none());
   m.def("dequantize", &transformer_engine::pytorch::dequantize, "Dequantize", py::arg("input"),
         py::arg("otype"));
+  m.def("create_empty_quantized_tensor",
+        &transformer_engine::pytorch::create_empty_quantized_tensor,
+        "Create an empty quantized tensor", py::arg("quantizer"), py::arg("shape"),
+        py::arg("dtype"), py::arg("device"), py::arg("pin_memory"));
   m.def("bgrad_quantize", transformer_engine::pytorch::bgrad_quantize,
         "Compute bias gradient and quantize", py::arg("input"), py::arg("quantizer"));
diff --git a/transformer_engine/pytorch/csrc/quantizer.cpp b/transformer_engine/pytorch/csrc/quantizer.cpp
index 1c968e276d..dc1f6ebae4 100644
--- a/transformer_engine/pytorch/csrc/quantizer.cpp
+++ b/transformer_engine/pytorch/csrc/quantizer.cpp
@@ -73,9 +73,11 @@ Float8Quantizer::Float8Quantizer(const py::handle& quantizer) : Quantizer(quanti
 }
 
 std::pair<TensorWrapper, py::object> NoneQuantizer::create_tensor(const std::vector<size_t>& shape,
-                                                                  DType dtype) const {
+                                                                  DType dtype, at::Device device,
+                                                                  bool pin_memory) const {
   const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-  const auto opts = at::TensorOptions().dtype(GetATenDType(dtype)).device(torch::kCUDA);
+  const auto opts =
+      at::TensorOptions().dtype(GetATenDType(dtype)).device(device).pinned_memory(pin_memory);
   return create_tensor(shape, dtype, at::empty(shape_int64, opts));
 }
 
@@ -113,22 +115,26 @@ void Float8Quantizer::set_quantization_params(TensorWrapper* tensor) const {
 }
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
-    const std::vector<size_t>& shape, DType dtype) const {
-  const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
+  const auto opts =
+      at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
   at::Tensor scale_inv = at::empty(std::vector<int64_t>{1}, opts);
-  return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv));
+  return create_tensor(shape, dtype, std::nullopt, std::nullopt, std::move(scale_inv), device,
+                       pin_memory);
 }
 
 std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
     const std::vector<size_t>& shape, DType dtype, std::optional<at::Tensor> data,
-    std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv) const {
+    std::optional<at::Tensor> transpose, std::optional<at::Tensor> scale_inv, at::Device device,
+    bool pin_memory) const {
   using namespace pybind11::literals;
 
   // Initialize data tensor
   const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
   if (with_data && !data) {
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    const auto opts =
+        at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
     data = at::empty(shape_int64, opts);
   } else if (!with_data && data) {
     data.reset();
@@ -139,7 +145,8 @@ std::pair<TensorWrapper, py::object> Float8Quantizer::create_tensor(
   const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (with_transpose && !transpose) {
     const auto transpose_shape = make_transpose_shape(shape);
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    const auto opts =
+        at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
     transpose = at::empty(transpose_shape, opts);
   } else if (!with_transpose && transpose) {
     transpose.reset();
@@ -325,7 +332,7 @@ void Float8CurrentScalingQuantizer::set_quantization_params(TensorWrapper* tenso
 }
 
 std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tensor(
-    const std::vector<size_t>& shape, DType dtype) const {
+    const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
   using namespace pybind11::literals;
 
   // Initialize data tensor
@@ -333,7 +340,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   const bool with_data = rowwise_usage || nvte_is_non_tn_fp8_gemm_supported();
   if (with_data) {
     const std::vector<int64_t> shape_int64(shape.begin(), shape.end());
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    const auto opts =
+        at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
     data_tensor = at::empty(shape_int64, opts);
   }
 
@@ -342,7 +350,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   const bool with_transpose = columnwise_usage && !nvte_is_non_tn_fp8_gemm_supported();
   if (with_transpose) {
     const auto transpose_shape = make_transpose_shape(shape);
-    const auto opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+    const auto opts =
+        at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
     transpose_tensor = at::empty(transpose_shape, opts);
   }
 
@@ -350,7 +359,8 @@ std::pair<TensorWrapper, py::object> Float8CurrentScalingQuantizer::create_tenso
   at::Tensor scale_inv_tensor;
   {
     const std::vector<int64_t> scale_inv_shape = {1};
-    const auto opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+    const auto opts =
+        at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
     scale_inv_tensor = at::empty(scale_inv_shape, opts);
   }
 
@@ -562,7 +572,7 @@ Float8BlockQuantizer::Float8BlockQuantizer(const py::handle& quantizer) : Quanti
 void Float8BlockQuantizer::set_quantization_params(TensorWrapper* tensor) const {}
 
 std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
-    const std::vector<size_t>& shape, DType dtype) const {
+    const std::vector<size_t>& shape, DType dtype, at::Device device, bool pin_memory) const {
   using namespace pybind11::literals;
   std::vector<int64_t> torch_shape;
   for (auto s : shape) {
@@ -573,8 +583,8 @@ std::pair<TensorWrapper, py::object> Float8BlockQuantizer::create_tensor(
   at::TensorOptions opts;
   at::TensorOptions scale_opts;
   at::Tensor data_rowwise, data_colwise, scale_inv_rowwise, scale_inv_colwise;
-  opts = opts.dtype(torch::kUInt8).device(torch::kCUDA);
-  scale_opts = scale_opts.dtype(torch::kFloat32).device(torch::kCUDA);
+  opts = opts.dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
+  scale_opts = scale_opts.dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
 
   if (rowwise_usage) {
     data_rowwise = at::empty(torch_shape, opts);
@@ -858,7 +868,8 @@ MXFP8Quantizer::MXFP8Quantizer(const py::handle& quantizer) : Quantizer(quantize
 void MXFP8Quantizer::set_quantization_params(TensorWrapper* tensor) const {}
 
 std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::vector<size_t>& shape,
-                                                                   DType dtype) const {
+                                                                   DType dtype, at::Device device,
+                                                                   bool pin_memory) const {
   using namespace pybind11::literals;
 
   // Scaling factor format
@@ -882,7 +893,8 @@ std::pair<TensorWrapper, py::object> MXFP8Quantizer::create_tensor(const std::ve
   // Allocate tensors
   at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor;
   at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor;
-  const auto uint8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
+  const auto uint8_tensor_opts =
+      at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
   if (rowwise_usage) {
     const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                      rowwise_scale_inv_shape.end());
@@ -1132,7 +1144,8 @@ void NVFP4Quantizer::set_quantization_params(TensorWrapper* tensor) const {
 }
 
 std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::vector<size_t>& shape,
-                                                                   DType dtype) const {
+                                                                   DType dtype, at::Device device,
+                                                                   bool pin_memory) const {
   using namespace pybind11::literals;
 
   // Scaling factor format
@@ -1158,8 +1171,10 @@ std::pair<TensorWrapper, py::object> NVFP4Quantizer::create_tensor(const std::ve
   // Allocate tensors
   at::Tensor rowwise_data_tensor, rowwise_scale_inv_tensor, amax_rowwise;
   at::Tensor columnwise_data_tensor, columnwise_scale_inv_tensor, amax_columnwise;
-  const auto bit8_tensor_opts = at::TensorOptions().dtype(torch::kUInt8).device(torch::kCUDA);
-  const auto bit32_tensor_opts = at::TensorOptions().dtype(torch::kFloat32).device(torch::kCUDA);
+  const auto bit8_tensor_opts =
+      at::TensorOptions().dtype(torch::kUInt8).device(device).pinned_memory(pin_memory);
+  const auto bit32_tensor_opts =
+      at::TensorOptions().dtype(torch::kFloat32).device(device).pinned_memory(pin_memory);
   if (rowwise_usage) {
     const std::vector<int64_t> scale_inv_shape_int64(rowwise_scale_inv_shape.begin(),
                                                      rowwise_scale_inv_shape.end());
diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py
index 0a6ad61ff0..84ec58ada2 100644
--- a/transformer_engine/pytorch/quantized_tensor.py
+++ b/transformer_engine/pytorch/quantized_tensor.py
@@ -13,6 +13,8 @@
 import torch
 from torch.utils._pytree import tree_map
 
+import transformer_engine_torch as tex
+
 from transformer_engine.common.recipe import Recipe
 from transformer_engine.pytorch.tensor._quantization_helpers import (
     _QuantizeFunc,
@@ -272,13 +274,26 @@ def make_empty(
         shape: Iterable[int],
         *,
         dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
+        device: Optional[Union[torch.device, str]] = None,
+        requires_grad: bool = False,
+        pin_memory: bool = False,
     ) -> QuantizedTensor:
         """Construct quantized tensor with uninitialized data"""
-        raise NotImplementedError(
-            f"{self.__class__.__name__} class does not implement make_empty function, "
-            "required for construction of unintialized quantized tensor"
+
+        if device is None:
+            device = torch.device("cuda")
+        # Handle the device passed as string
+        device = torch.device(device)
+        result = tex.create_empty_quantized_tensor(
+            self,
+            list(shape),
+            dtype,
+            device,
+            pin_memory,
         )
+        if requires_grad:
+            result.requires_grad_(True)
+        return result
 
     def calibrate(self, tensor: torch.Tensor) -> None:
         """Calibrate quantizer state
diff --git a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
index ecafb6ddfc..5d9ebb2f96 100644
--- a/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_blockwise_tensor.py
@@ -202,62 +202,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool:
             return False
         return True
 
-    def make_empty(
-        self,
-        shape: Iterable[int],
-        *,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-        requires_grad: bool = False,
-        pin_memory: bool = False,
-    ) -> Float8BlockwiseQTensor:
-        """Construct quantized tensor with uninitialized data"""
-
-        tensor_kwargs = {
-            "device": torch.device("cuda") if device is None else device,
-            "pin_memory": pin_memory,
-        }
-
-        # Allocate buffers for row-scaled data
-        rowwise_data = None
-        rowwise_scale_inv = None
-        if self.rowwise_usage:
-            rowwise_data = torch.empty(shape, dtype=torch.uint8, **tensor_kwargs)
-            rowwise_scale_inv = torch.empty(
-                self.get_scale_shape(shape, columnwise=False),
-                dtype=torch.float32,
-                **tensor_kwargs,
-            )
-
-        # Allocate buffers for column-scaled data
-        columnwise_data = None
-        columnwise_scale_inv = None
-        if self.columnwise_usage:
-            columnwise_data = torch.empty(
-                self.get_columnwise_shape(shape),
-                dtype=torch.uint8,
-                **tensor_kwargs,
-            )
-            columnwise_scale_inv = torch.empty(
-                self.get_scale_shape(shape, columnwise=True),
-                dtype=torch.float32,
-                **tensor_kwargs,
-            )
-
-        # Construct FP8 tensor
-        return Float8BlockwiseQTensor(
-            shape=shape,
-            dtype=dtype,
-            fp8_dtype=self.dtype,
-            rowwise_data=rowwise_data,
-            rowwise_scale_inv=rowwise_scale_inv,
-            columnwise_data=columnwise_data,
-            columnwise_scale_inv=columnwise_scale_inv,
-            quantizer=self,
-            is_2D_scaled=self.block_scaling_dim == 2,
-            requires_grad=requires_grad,
-        )
-
     def calibrate(self, tensor: torch.Tensor) -> None:
         # NOTE: This interface is specific to requirements like delayed scaling
         # where state from an estimator influences distribution parameters.
diff --git a/transformer_engine/pytorch/tensor/float8_tensor.py b/transformer_engine/pytorch/tensor/float8_tensor.py
index 3aeace0a77..b8b0de10c4 100644
--- a/transformer_engine/pytorch/tensor/float8_tensor.py
+++ b/transformer_engine/pytorch/tensor/float8_tensor.py
@@ -108,48 +108,6 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Quantize tensor implementation"""
         return tex.quantize(tensor, self)
 
-    def make_empty(
-        self,
-        shape: Iterable[int],
-        *,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-        requires_grad: bool = False,
-        pin_memory: bool = False,
-    ) -> Float8Tensor:
-
-        # Canonicalize tensor attributes
-        if device is None:
-            device = torch.device("cuda")
-
-        # Allocate FP8 data
-        data = None
-        if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-
-        # Allocate FP8 data transpose if needed
-        data_transpose = None
-        if self.columnwise_usage:
-            transpose_shape = [shape[-1]] + list(shape[:-1])
-            data_transpose = torch.empty(
-                transpose_shape,
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-
-        # Construct FP8 tensor
-        return Float8Tensor(
-            shape=shape,
-            dtype=dtype,
-            data=data,
-            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
-            fp8_dtype=self.dtype,
-            requires_grad=requires_grad,
-            data_transpose=data_transpose,
-            quantizer=self,
-        )
-
     def calibrate(self, tensor: torch.Tensor) -> None:
         amin, amax = tensor.aminmax()
         self.amax.copy_(torch.max(-amin, amax))
@@ -325,47 +283,6 @@ def quantize_impl(self, tensor: torch.Tensor) -> QuantizedTensor:
         """Quantize tensor implementation"""
         return tex.quantize(tensor, self)
 
-    def make_empty(
-        self,
-        shape: Iterable[int],
-        *,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-        requires_grad: bool = False,
-        pin_memory: bool = False,
-    ) -> Float8Tensor:
-
-        # Canonicalize tensor attributes
-        if device is None:
-            device = torch.device("cuda")
-
-        # Allocate FP8 data
-        data = None
-        if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-
-        # Allocate FP8 data transpose if needed
-        data_transpose = None
-        if self.columnwise_usage:
-            transpose_shape = [shape[-1]] + list(shape[:-1])
-            data_transpose = torch.empty(
-                transpose_shape,
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-        # Construct FP8 tensor
-        return Float8Tensor(
-            shape=shape,
-            dtype=dtype,
-            data=data,
-            fp8_scale_inv=torch.empty(1, dtype=torch.float32, device=device, pin_memory=pin_memory),
-            fp8_dtype=self.dtype,
-            requires_grad=requires_grad,
-            data_transpose=data_transpose,
-            quantizer=self,
-        )
-
     def calibrate(self, tensor: torch.Tensor) -> None:
         # current scaling don't need to calibrate
         return
diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
index 8dd2255d89..09a411060a 100644
--- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py
+++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -16,7 +16,7 @@
 from transformer_engine.common.recipe import MXFP8BlockScaling, Recipe
 
 from ..constants import MXFP8_BLOCK_SCALING_SIZE
-from ..utils import devices_match, round_up_to_nearest_multiple
+from ..utils import devices_match
 from .storage.mxfp8_tensor_storage import MXFP8TensorStorage, _FromMXFP8Func
 from ..quantized_tensor import QuantizedTensor, Quantizer
 from ._quantization_helpers import _IdentityFunc
@@ -96,70 +96,6 @@ def is_quantizable(self, inp: torch.Tensor) -> bool:
             return False
         return True
 
-    def make_empty(
-        self,
-        shape: Iterable[int],
-        *,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-        requires_grad: bool = False,
-        pin_memory: bool = False,
-    ) -> MXFP8Tensor:
-
-        # Canonicalize tensor attributes
-        if device is None:
-            device = torch.device("cuda")
-
-        assert (
-            shape[-1] % MXFP8_BLOCK_SCALING_SIZE == 0
-            and math.prod(shape[:-1]) % MXFP8_BLOCK_SCALING_SIZE == 0
-        ), (
-            f"Incorrect shape {shape} for MXFP8. Tensor dims must be divisible by"
-            f" {MXFP8_BLOCK_SCALING_SIZE}"
-        )
-
-        # Allocate FP8 data
-        data = None
-        scale_inv = None
-        if self.rowwise_usage:
-            data = torch.empty(shape, dtype=torch.uint8, device=device, pin_memory=pin_memory)
-            scale_inv = torch.empty(
-                round_up_to_nearest_multiple(math.prod(shape[:-1]), 128),
-                round_up_to_nearest_multiple(shape[-1] // MXFP8_BLOCK_SCALING_SIZE, 4),
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-
-        # Allocate FP8 data transpose if needed
-        columnwise_data = None
-        columnwise_scale_inv = None
-        if self.columnwise_usage:
-            columnwise_data = torch.empty(
-                shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-            )
-            columnwise_scale_inv = torch.empty(
-                round_up_to_nearest_multiple(math.prod(shape[:-1]) // MXFP8_BLOCK_SCALING_SIZE, 4),
-                round_up_to_nearest_multiple(shape[-1], 128),
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-
-        # Construct FP8 tensor
-        return MXFP8Tensor(
-            shape=shape,
-            dtype=dtype,
-            fp8_dtype=self.dtype,
-            rowwise_data=data,
-            rowwise_scale_inv=scale_inv,
-            columnwise_data=columnwise_data,
-            columnwise_scale_inv=columnwise_scale_inv,
-            quantizer=self,
-            requires_grad=requires_grad,
-            with_gemm_swizzled_scales=self.optimize_for_gemm,
-        )
-
     def calibrate(self, tensor: torch.Tensor) -> None:
         # TODO(ksivamani): No calibration needed for mxfp8?
         pass
diff --git a/transformer_engine/pytorch/tensor/nvfp4_tensor.py b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
index 101cf78a8f..86b3a55641 100644
--- a/transformer_engine/pytorch/tensor/nvfp4_tensor.py
+++ b/transformer_engine/pytorch/tensor/nvfp4_tensor.py
@@ -282,87 +282,6 @@ def convert_shape_for_fp4(shape: Iterable[int]) -> Tuple[int, ...]:
         shape[-1] = shape[-1] // 2
         return tuple(shape)
 
-    def make_empty(
-        self,
-        shape: Iterable[int],
-        *,
-        dtype: torch.dtype = torch.float32,
-        device: Optional[torch.device] = None,
-        pin_memory: bool = False,
-        requires_grad: bool = False,
-    ) -> NVFP4Tensor:
-
-        # Canonicalize tensor attributes
-        if device is None:
-            device = torch.device("cuda")
-
-        assert shape[-1] % NVFP4_BLOCK_SCALING_SIZE == 0, (
-            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
-            f" {NVFP4_BLOCK_SCALING_SIZE}"
-        )
-
-        flat_first_dim = math.prod(shape[:-1])
-        assert flat_first_dim % NVFP4_BLOCK_SCALING_SIZE == 0, (
-            f"Incorrect shape {shape} for NVFP4. Tensor dims must be divisible by"
-            f" {NVFP4_BLOCK_SCALING_SIZE}"
-        )
-
-        # Allocate FP4 data
-        data = None
-        scale_inv = None
-        amax_rowwise = None
-        if self.rowwise_usage:
-            data = torch.empty(
-                self.convert_shape_for_fp4(shape),
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-            scale_shape = self.get_scale_shape(shape, columnwise=False)
-            scale_inv = torch.empty(
-                scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-            )
-            # Allocate per tensor scale inverse. FP32 format.
-            amax_rowwise = torch.zeros(1, dtype=torch.float32, device=device, pin_memory=pin_memory)
-
-        # Allocate FP8 data transpose if needed
-        columnwise_data = None
-        columnwise_scale_inv = None
-        amax_columnwise = None
-        if self.columnwise_usage:
-            # enforce 2D shape to avoid [S, B, H] shape and B and be 1
-            # and the transposed shape is [H, S, B], so divide last dim by 2 gives zero
-            shape_2d = tuple([flat_first_dim, shape[-1]])
-            columnwise_data = torch.empty(
-                self.convert_shape_for_fp4(self.get_columnwise_shape(shape_2d)),
-                dtype=torch.uint8,
-                device=device,
-                pin_memory=pin_memory,
-            )
-            columnwise_scale_shape = self.get_scale_shape(shape, columnwise=True)
-            columnwise_scale_inv = torch.empty(
-                columnwise_scale_shape, dtype=torch.uint8, device=device, pin_memory=pin_memory
-            )
-            amax_columnwise = torch.zeros(
-                1, dtype=torch.float32, device=device, pin_memory=pin_memory
-            )
-
-        # Construct FP8 tensor
-        return NVFP4Tensor(
-            shape=shape,
-            dtype=dtype,
-            rowwise_data=data,
-            rowwise_scale_inv=scale_inv,
-            columnwise_data=columnwise_data,
-            columnwise_scale_inv=columnwise_scale_inv,
-            amax_rowwise=amax_rowwise,
-            amax_columnwise=amax_columnwise,
-            fp4_dtype=self.dtype,
-            quantizer=self,
-            requires_grad=requires_grad,
-            with_gemm_swizzled_scales=False,
-        )
-
     def calibrate(self, tensor: torch.Tensor) -> None:
         pass  # Calibration is no-op
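
Note: taken together, these changes route every quantizer's make_empty through a single C++ entry point. The Python base class normalizes the device argument, calls the new tex.create_empty_quantized_tensor binding, and the per-format make_empty overrides in the Float8, Float8 blockwise, MXFP8, and NVFP4 tensor classes become unnecessary. A minimal usage sketch follows; it assumes the Float8Quantizer constructor takes (scale, amax, fp8_dtype) and that tex.DType exposes kFloat8E4M3, neither of which is part of this diff.

# Hedged sketch of the unified allocation path added by this diff.
# Assumption: Float8Quantizer(scale, amax, fp8_dtype) and tex.DType.kFloat8E4M3
# exist with these signatures; they are not shown in the patch above.
import torch
import transformer_engine_torch as tex
from transformer_engine.pytorch.tensor.float8_tensor import Float8Quantizer

quantizer = Float8Quantizer(
    scale=torch.ones(1, device="cuda"),
    amax=torch.zeros(1, device="cuda"),
    fp8_dtype=tex.DType.kFloat8E4M3,
)

# The base-class make_empty now handles every quantizer type: it converts a
# string device to torch.device, forwards to tex.create_empty_quantized_tensor,
# and optionally sets requires_grad on the returned quantized tensor.
empty = quantizer.make_empty(
    (128, 256),
    dtype=torch.bfloat16,
    device="cuda",      # strings accepted; normalized via torch.device(device)
    pin_memory=False,   # forwarded to the C++ TensorOptions
)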