Skip to content
5 changes: 4 additions & 1 deletion transformer_engine/pytorch/quantized_tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -466,7 +466,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
dst.quantize_(src)
else:
if isinstance(src, QuantizedTensor):
src = src.dequantize()
dtype = dst.dtype
if dtype not in (torch.float32, torch.float16, torch.bfloat16):
dtype = torch.float32
src = src.dequantize(dtype=dtype)
dst.copy_(src)
return None

Expand Down
Loading