From 2d315e6c3973afaaf62884bbd175d63d35bb336a Mon Sep 17 00:00:00 2001 From: Zhiyi Su Date: Tue, 26 Aug 2025 17:48:16 +0800 Subject: [PATCH 1/2] Adds dst.dtype information in copy_ method of quantized tensors. Signed-off-by: Zhiyi Su --- transformer_engine/pytorch/tensor/quantized_tensor.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 656eda46ca..5621ba0795 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -427,7 +427,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.quantize_(src) else: if isinstance(src, QuantizedTensor): - src = src.dequantize() + src = src.dequantize(dtype=dst.dtype) dst.copy_(src) return None From f06db0d139f7819c8efc058a4e51754f9039a6fd Mon Sep 17 00:00:00 2001 From: ZhiyiDanielSu <35579247+zobeideThePlayer@users.noreply.github.com> Date: Thu, 28 Aug 2025 11:03:49 +0800 Subject: [PATCH 2/2] Update transformer_engine/pytorch/tensor/quantized_tensor.py Co-authored-by: Tim Moon <4406448+timmoon10@users.noreply.github.com> Signed-off-by: ZhiyiDanielSu <35579247+zobeideThePlayer@users.noreply.github.com> --- transformer_engine/pytorch/tensor/quantized_tensor.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/transformer_engine/pytorch/tensor/quantized_tensor.py b/transformer_engine/pytorch/tensor/quantized_tensor.py index 5621ba0795..97930753d7 100644 --- a/transformer_engine/pytorch/tensor/quantized_tensor.py +++ b/transformer_engine/pytorch/tensor/quantized_tensor.py @@ -427,7 +427,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.quantize_(src) else: if isinstance(src, QuantizedTensor): - src = src.dequantize(dtype=dst.dtype) + dtype = dst.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = torch.float32 + src = src.dequantize(dtype=dtype) dst.copy_(src) return None