Add NVTE_KEEP_BACKWARD_UNQUANTIZED #2644
base: main
Conversation
Greptile Overview
Greptile Summary
Critical Issues:
Implementation:
Confidence Score: 1/5
Important Files Changed
Sequence Diagram
```mermaid
sequenceDiagram
participant User
participant Recipe
participant Linear
participant BasicLinear
participant Quantize
User->>Recipe: Set NVTE_KEEP_BACKWARD_UNQUANTIZED=1
Recipe->>Recipe: quantize_backward = False
Note over Recipe: DelayedScaling: CRASHES HERE<br/>(assertion at line 220)
User->>Linear: forward(input)
Linear->>Linear: keep_backward_unquantized = True
Linear->>Linear: save_original_input = True
Linear->>Quantize: quantize(input)
Quantize->>Quantize: Check recipe.quantize_forward
Note over Quantize: Potential crash if recipe is None
Quantize-->>Linear: quantized_input (FP8)
Linear->>BasicLinear: forward(quantized_input, weight)
BasicLinear->>BasicLinear: Save high-precision input for backward
BasicLinear-->>Linear: output
User->>Linear: backward(grad_output)
Linear->>BasicLinear: backward(grad_output)
Note over BasicLinear: Uses high-precision saved tensors<br/>Skip quantization in backward
BasicLinear->>BasicLinear: wgrad = grad_output @ input_hp
BasicLinear->>BasicLinear: dgrad = grad_output @ weight_hp
BasicLinear-->>Linear: grad_input (high precision)
Linear-->>User: gradients (BF16/FP32)
```
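To make the diagram concrete, here is a toy autograd sketch of the same idea. The FP8 quantization is faked with an FP16 cast, and none of the names below come from the PR's actual code; it only illustrates the save-high-precision-for-backward flow.

```python
import os

import torch

KEEP_BWD_UNQUANTIZED = os.getenv("NVTE_KEEP_BACKWARD_UNQUANTIZED", "0") == "1"


class ToyQuantizedLinear(torch.autograd.Function):
    """Toy stand-in: FP8 quantization is faked with an FP16 cast."""

    @staticmethod
    def forward(ctx, inp, weight):
        # "Quantized" forward compute (stand-in for the FP8 GEMM).
        inp_q = inp.to(torch.float16)
        w_q = weight.to(torch.float16)
        out = inp_q @ w_q.t()

        # Key idea of the env var: keep high-precision tensors for backward.
        if KEEP_BWD_UNQUANTIZED:
            ctx.save_for_backward(inp, weight)  # BF16/FP32 copies
        else:
            ctx.save_for_backward(inp_q, w_q)   # "quantized" copies
        return out.to(inp.dtype)

    @staticmethod
    def backward(ctx, grad_out):
        saved_inp, saved_w = ctx.saved_tensors
        dgrad = grad_out @ saved_w.to(grad_out.dtype)        # dL/d(input)
        wgrad = grad_out.t() @ saved_inp.to(grad_out.dtype)  # dL/d(weight)
        return dgrad, wgrad
```

With the env var set, backward consumes the high-precision copies, which is the flow the diagram shows.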
I'll work on potential unit test breakage.
```python
ln_out_return = None
if return_layernorm_output or return_layernorm_output_gathered:
    ln_out_return = ln_out
ln_out_hp = ln_out if keep_backward_unquantized else None
```
Storing both ln_out (quantized) and ln_out_hp (high precision) roughly doubles the memory footprint for this activation. Verify that this overhead is acceptable for your target models, especially when training with large batch sizes or long sequences.
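As a rough illustration of that overhead (the shape and dtypes below are hypothetical, not from the PR):

```python
# Back-of-the-envelope for one LayerNorm output saved for backward.
# Hypothetical shape and dtypes; adjust for your model.
batch, seq, hidden = 8, 4096, 8192
elems = batch * seq * hidden

fp8_bytes = elems * 1   # quantized ln_out (scale/amax metadata ignored)
bf16_bytes = elems * 2  # high-precision ln_out_hp

print(f"FP8 copy:  {fp8_bytes / 2**20:.0f} MiB")   # 256 MiB
print(f"BF16 copy: {bf16_bytes / 2**20:.0f} MiB")  # 512 MiB
print(f"Both:      {(fp8_bytes + bf16_bytes) / 2**20:.0f} MiB per layer")  # 768 MiB
```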
```python
# Check if FP8 is enabled
fp8_enabled = FP8GlobalStateManager.is_fp8_enabled()
quantize_forward = fp8_enabled and self._quantize_forward
quantize_backward = fp8_enabled and self._quantize_backward
quantize_backward = (
```
Recipe None crash
FP8GlobalStateManager.get_fp8_recipe() can be None (e.g., if FP8 is enabled but no recipe was set), so ...get_fp8_recipe().quantize_backward will raise an AttributeError. This makes Quantize.forward() crash in that configuration; please guard for None (e.g., default to quantized backward when recipe is unset, or explicitly error with a clearer message).
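One possible shape for that guard, following the "explicitly error" option the comment suggests (the message text and surrounding variable names are assumptions, not the PR's code):

```python
# Sketch: fail with a clear message instead of an AttributeError when FP8 is
# enabled but no recipe is set (fallback choice and wording are assumptions).
recipe = FP8GlobalStateManager.get_fp8_recipe()
if recipe is None:
    raise RuntimeError(
        "FP8 is enabled but no recipe is set; cannot determine "
        "quantize_forward/quantize_backward for the Quantize op."
    )
quantize_forward = quantize_forward and recipe.quantize_forward
quantize_backward = quantize_backward and recipe.quantize_backward
```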
```python
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
```
This assertion prevents using NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe: when the env var is set, quantize_backward becomes False, making this assert fail and blocking the entire feature for this recipe type.
```python
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
```
This assertion makes LayerNormMLP crash immediately when NVTE_KEEP_BACKWARD_UNQUANTIZED=1 is set. If this module doesn't support the feature, either implement it or fail more gracefully with a clear error message before reaching this point.
```diff
  # Save state for backward pass
  if ctx.requires_grad:
+     saved_input = input_ if keep_backward_unquantized else x_local
+     saved_weight = self.weight if keep_backward_unquantized else w
      if is_cpu_offload_enabled():
-         mark_activation_offload(x_local)
-     ctx.save_for_backward(x_local, w)
      ctx.with_quantized_compute = with_quantized_compute
+         mark_activation_offload(saved_input)
+     ctx.save_for_backward(saved_input, saved_weight)
```
Unnecessary saved tensors
The forward path now saves saved_input/saved_weight whenever ctx.requires_grad is true, even when weight_requires_grad or input_requires_grad is false. In cases like frozen weights (common when finetuning) or when only one side needs gradients, this saves extra tensors and can materially increase activation memory. The prior logic of if not weight_requires_grad: saved_input = None / if not input_requires_grad: saved_weight = None avoided that.
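A small self-contained helper sketching the gating this comment describes (hypothetical name and signature, not code from the PR):

```python
from typing import Optional, Tuple

import torch


def select_saved_tensors(
    inp: torch.Tensor,
    weight: torch.Tensor,
    input_requires_grad: bool,
    weight_requires_grad: bool,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """Drop whichever saved tensor the backward pass will not use."""
    saved_input = inp if weight_requires_grad else None      # input is only needed for wgrad
    saved_weight = weight if input_requires_grad else None   # weight is only needed for dgrad
    return saved_input, saved_weight
```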
```
quantize_forward : bool, default = True
    Whether to quantize tensors in the forward pass.
quantize_backward : bool, default = True
    Whether to quantize tensors in the backward pass.
```
Not sure we need that for the custom recipe, since there we can just specify the quantizers we want, but sure, we can have it to keep the API consistent.
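For illustration, the flags documented above might be used like this. Float8CurrentScaling is only an example recipe class, and the quantize_forward/quantize_backward keywords are the ones this PR proposes, so the final API may differ:

```python
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

# Hypothetical: quantize forward GEMMs, keep the backward pass in high precision.
recipe = Float8CurrentScaling(quantize_forward=True, quantize_backward=False)

with te.fp8_autocast(enabled=True, fp8_recipe=recipe):
    ...  # run TE modules here
```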
```python
)
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
```
This is unfortunate.
```python
)
if keep_backward_unquantized:
    # Note, NVTE_KEEP_BACKWARD_UNQUANTIZED is ignored when delayed scaling is used
    save_original_input = True
```
We should also make sure that we don't create the columnwise version of the input.
The input_quantizer columnwise is disabled here: https://github.com/NVIDIA/TransformerEngine/pull/2644/changes/BASE..253873a4560b2c2a2c909918cc3ee26500e5b43d#diff-864ad36a21c571fb178499535cfada611df4a82223c9ffbfea872dda39972eaeR335-R342
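For reference, a sketch of what "disabling the columnwise version" could look like; set_usage(rowwise=..., columnwise=...) on the quantizer is assumed here, and the linked diff is the authoritative change:

```python
# Sketch only: when the high-precision input is saved for wgrad, the quantized
# input no longer needs a columnwise (transposed) copy.
# Assumes the quantizer exposes set_usage(rowwise=..., columnwise=...).
if keep_backward_unquantized and input_quantizer is not None:
    input_quantizer.set_usage(rowwise=True, columnwise=False)
```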
@zianglih Thank you for your contribution!
```python
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
```
This blocks NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with DelayedScaling: when the env var is set, quantize_backward becomes False, triggering this assertion and preventing the feature from working with this recipe type.
```python
assert (
    not keep_backward_unquantized
```
Hard crash when NVTE_KEEP_BACKWARD_UNQUANTIZED=1: LayerNormMLP becomes completely unusable with this env var.
Hi @zhongbozhu @timmoon10 @ptrendx, thank you so much for reviewing! I have implemented and added the unit test. All new tests passed.
Hi @timmoon10, @ptrendx,
This design was from @timmoon10's comment here: #2644 (comment)
Which way do we prefer? Thanks!
```python
assert not (
    not self.quantize_forward and self.quantize_backward
), "Invalid recipe configuration: quantize_backward=True requires quantize_forward=True."
assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
```
Blocks NVTE_KEEP_BACKWARD_UNQUANTIZED=1 with the DelayedScaling recipe
When the env var is set, quantize_backward becomes False, but this assertion requires it to be True; the feature cannot work with this recipe type at all.
Suggested change:
```diff
- assert self.quantize_backward, "Delayed scaling does not support quantize_backward=False."
+ # Note: DelayedScaling does not support quantize_backward=False yet
```
```python
assert (
    not keep_backward_unquantized
), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
```
Hard crash when NVTE_KEEP_BACKWARD_UNQUANTIZED=1
Setting the env var makes LayerNormMLP completely unusable; it crashes immediately on first use.
Suggested change:
```diff
- assert (
-     not keep_backward_unquantized
- ), "NVTE_KEEP_BACKWARD_UNQUANTIZED is not implemented in LayerNormMLP"
+ if keep_backward_unquantized:
+     raise NotImplementedError(
+         "NVTE_KEEP_BACKWARD_UNQUANTIZED is not yet implemented in LayerNormMLP"
+     )
```
```python
# Recipe quantize overrides
if FP8GlobalStateManager.get_fp8_recipe() is not None:
    quantize_forward = (
        quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward
    )
    quantize_backward = (
        quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward
```
get_fp8_recipe() returns None when FP8 is enabled but no recipe is set; calling .quantize_backward on None will crash with an AttributeError.
Suggested change:
```diff
- # Recipe quantize overrides
- if FP8GlobalStateManager.get_fp8_recipe() is not None:
-     quantize_forward = (
-         quantize_forward and FP8GlobalStateManager.get_fp8_recipe().quantize_forward
-     )
-     quantize_backward = (
-         quantize_backward and FP8GlobalStateManager.get_fp8_recipe().quantize_backward
+ # Recipe quantize overrides
+ recipe = FP8GlobalStateManager.get_fp8_recipe()
+ if recipe is not None:
+     quantize_forward = quantize_forward and recipe.quantize_forward
+     quantize_backward = quantize_backward and recipe.quantize_backward
```
Full unit test results, with the newly added
Description
Add an NVTE_KEEP_BACKWARD_UNQUANTIZED env var for quantized fprop + high precision wgrad & dgrad.
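A usage sketch (the recipe and layer below are illustrative only; per the review discussion, DelayedScaling and LayerNormMLP do not yet support the env var):

```python
import os

# Set before running TE modules so the flag is picked up.
os.environ["NVTE_KEEP_BACKWARD_UNQUANTIZED"] = "1"

import torch
import transformer_engine.pytorch as te
from transformer_engine.common.recipe import Float8CurrentScaling

layer = te.Linear(4096, 4096, params_dtype=torch.bfloat16).cuda()
inp = torch.randn(16, 4096, dtype=torch.bfloat16, device="cuda", requires_grad=True)

# Forward GEMM runs quantized; wgrad/dgrad use the saved high-precision tensors.
with te.fp8_autocast(enabled=True, fp8_recipe=Float8CurrentScaling()):
    out = layer(inp)
out.sum().backward()
```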
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: