diff --git a/transformer_engine/pytorch/quantized_tensor.py b/transformer_engine/pytorch/quantized_tensor.py index 0a6ad61ff0..3aefda38e2 100644 --- a/transformer_engine/pytorch/quantized_tensor.py +++ b/transformer_engine/pytorch/quantized_tensor.py @@ -466,7 +466,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): dst.quantize_(src) else: if isinstance(src, QuantizedTensor): - src = src.dequantize() + dtype = dst.dtype + if dtype not in (torch.float32, torch.float16, torch.bfloat16): + dtype = torch.float32 + src = src.dequantize(dtype=dtype) dst.copy_(src) return None