Adafactor._approx_sq_grad() missing epsilon guard causes silent NaN/Inf corruption
Repository: microsoft/LLM2CLIP
Description
The _approx_sq_grad() method in the Adafactor optimizer (llm2clip/training/fp16.py) divides by exp_avg_sq_row.mean(dim=-1) without an epsilon guard on the denominator:
```python
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output):
    r_factor = (
        (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1).unsqueeze(-1))
        #                 ^^^^^^^^^^^^^^^^^^^^^^^^^^^
        #                 can be 0.0 -> 0/0 yields NaN, which rsqrt_ propagates
        .rsqrt_()
        .unsqueeze(-1)
    )
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    torch.mul(r_factor, c_factor, out=output)
```

When all gradient values for a parameter are zero (dead neurons, fully masked parameters, or early training with sparse gradients), the exponential moving average exp_avg_sq_row can be all zeros or underflow to zero. In that case mean() returns 0.0, the division becomes 0/0 and yields NaN, and rsqrt_() propagates it, silently corrupting the optimizer state for all subsequent training steps.
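A minimal sketch of the failure mode, assuming a parameter whose row statistics have collapsed to exact zeros (the tensor shape here is illustrative, not taken from the repository):

```python
import torch

# Illustrative row statistics for a dead / fully masked parameter:
# the factored second-moment row accumulator is all zeros.
exp_avg_sq_row = torch.zeros(4)

row_mean = exp_avg_sq_row.mean(dim=-1).unsqueeze(-1)  # tensor([0.])
r_factor = (exp_avg_sq_row / row_mean).rsqrt_()       # 0/0 -> NaN; rsqrt_ keeps NaN
print(r_factor)                                        # tensor([nan, nan, nan, nan])
```

Once NaN enters r_factor, the torch.mul(..., out=output) call writes NaN into the whole update for that parameter, and it never recovers on later steps.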
For comparison, the HuggingFace Transformers implementation of Adafactor adds an epsilon term to prevent this:

```python
# HuggingFace version (safe):
r_factor = (exp_avg_sq_row / exp_avg_sq_row.mean(dim=-1, keepdim=True)).rsqrt_().unsqueeze(-1)
# with exp_avg_sq initialized with fill_value=eps, preventing a zero denominator
```

Suggested fix
Add a small epsilon to the denominator:

```python
def _approx_sq_grad(self, exp_avg_sq_row, exp_avg_sq_col, output):
    row_mean = exp_avg_sq_row.mean(dim=-1).unsqueeze(-1)
    r_factor = (exp_avg_sq_row / (row_mean + 1e-30)).rsqrt_().unsqueeze(-1)
    c_factor = exp_avg_sq_col.unsqueeze(-2).rsqrt()
    torch.mul(r_factor, c_factor, out=output)
```

Or initialize exp_avg_sq_row with fill_value=eps (as HuggingFace does) instead of zeros.
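A minimal sketch of that alternative initialization inside the optimizer's state setup; state, grad, and eps are hypothetical placeholders for the per-parameter state dict, the current gradient, and the optimizer's epsilon hyperparameter, since the actual names in llm2clip/training/fp16.py may differ:

```python
import torch

def init_factored_state(state, grad, eps=1e-30):
    # Hypothetical helper: start the factored second-moment accumulators at eps
    # instead of 0, so their means can never collapse to exactly zero later.
    state["exp_avg_sq_row"] = torch.full_like(grad.mean(dim=-1), eps)
    state["exp_avg_sq_col"] = torch.full_like(grad.mean(dim=-2), eps)

# Example shapes only:
state = {}
grad = torch.zeros(8, 16)
init_factored_state(state, grad)
print(state["exp_avg_sq_row"].shape, state["exp_avg_sq_col"].shape)
# torch.Size([8]) torch.Size([16])
```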
How this was found
This was identified via static analysis using a3-python, which flagged the unguarded division as a DSE-confirmed reachable DIV_ZERO.